diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 83a1578640699..ee7e00e3a3fbf 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -402,6 +402,8 @@ jobs:
runs-on: ubuntu-latest
permissions:
packages: write
+ env:
+ DOCKER_BUILD_RECORD_UPLOAD: false
steps:
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
@@ -1240,9 +1242,9 @@ jobs:
sudo apt update
sudo apt-get install r-base
- name: Start Minikube
- uses: medyagh/setup-minikube@v0.0.18
+ uses: medyagh/setup-minikube@v0.0.19
with:
- kubernetes-version: "1.32.0"
+ kubernetes-version: "1.33.0"
# Github Action limit cpu:2, memory: 6947MB, limit to 2U6G for better resource statistic
cpus: 2
memory: 6144m
diff --git a/.github/workflows/build_infra_images_cache.yml b/.github/workflows/build_infra_images_cache.yml
index ac139147beb91..ccd47826ff099 100644
--- a/.github/workflows/build_infra_images_cache.yml
+++ b/.github/workflows/build_infra_images_cache.yml
@@ -30,12 +30,17 @@ on:
- 'dev/spark-test-image/docs/Dockerfile'
- 'dev/spark-test-image/lint/Dockerfile'
- 'dev/spark-test-image/sparkr/Dockerfile'
+ - 'dev/spark-test-image/python-minimum/Dockerfile'
+ - 'dev/spark-test-image/python-ps-minimum/Dockerfile'
- 'dev/spark-test-image/pypy-310/Dockerfile'
- 'dev/spark-test-image/python-309/Dockerfile'
- 'dev/spark-test-image/python-310/Dockerfile'
- 'dev/spark-test-image/python-311/Dockerfile'
+ - 'dev/spark-test-image/python-311-classic-only/Dockerfile'
- 'dev/spark-test-image/python-312/Dockerfile'
- 'dev/spark-test-image/python-313/Dockerfile'
+ - 'dev/spark-test-image/python-313-nogil/Dockerfile'
+ - 'dev/spark-test-image/numpy-213/Dockerfile'
- '.github/workflows/build_infra_images_cache.yml'
# Create infra image when cutting down branches/tags
create:
@@ -187,6 +192,19 @@ jobs:
- name: Image digest (PySpark with Python 3.11)
if: hashFiles('dev/spark-test-image/python-311/Dockerfile') != ''
run: echo ${{ steps.docker_build_pyspark_python_311.outputs.digest }}
+ - name: Build and push (PySpark Classic Only with Python 3.11)
+ if: hashFiles('dev/spark-test-image/python-311-classic-only/Dockerfile') != ''
+ id: docker_build_pyspark_python_311_classic_only
+ uses: docker/build-push-action@v6
+ with:
+ context: ./dev/spark-test-image/python-311-classic-only/
+ push: true
+ tags: ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-python-311-classic-only-cache:${{ github.ref_name }}-static
+ cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-python-311-classic-only-cache:${{ github.ref_name }}
+ cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-python-311-classic-only-cache:${{ github.ref_name }},mode=max
+ - name: Image digest (PySpark Classic Only with Python 3.11)
+ if: hashFiles('dev/spark-test-image/python-311-classic-only/Dockerfile') != ''
+ run: echo ${{ steps.docker_build_pyspark_python_311_classic_only.outputs.digest }}
- name: Build and push (PySpark with Python 3.12)
if: hashFiles('dev/spark-test-image/python-312/Dockerfile') != ''
id: docker_build_pyspark_python_312
@@ -213,3 +231,29 @@ jobs:
- name: Image digest (PySpark with Python 3.13)
if: hashFiles('dev/spark-test-image/python-313/Dockerfile') != ''
run: echo ${{ steps.docker_build_pyspark_python_313.outputs.digest }}
+ - name: Build and push (PySpark with Python 3.13 no GIL)
+ if: hashFiles('dev/spark-test-image/python-313-nogil/Dockerfile') != ''
+ id: docker_build_pyspark_python_313_nogil
+ uses: docker/build-push-action@v6
+ with:
+ context: ./dev/spark-test-image/python-313-nogil/
+ push: true
+ tags: ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-python-313-nogil-cache:${{ github.ref_name }}-static
+ cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-python-313-nogil-cache:${{ github.ref_name }}
+ cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-python-313-nogil-cache:${{ github.ref_name }},mode=max
+ - name: Image digest (PySpark with Python 3.13 no GIL)
+ if: hashFiles('dev/spark-test-image/python-313-nogil/Dockerfile') != ''
+ run: echo ${{ steps.docker_build_pyspark_python_313_nogil.outputs.digest }}
+ - name: Build and push (PySpark with Numpy 2.1.3)
+ if: hashFiles('dev/spark-test-image/numpy-213/Dockerfile') != ''
+ id: docker_build_pyspark_numpy_213
+ uses: docker/build-push-action@v6
+ with:
+ context: ./dev/spark-test-image/numpy-213/
+ push: true
+ tags: ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-numpy-213-cache:${{ github.ref_name }}-static
+ cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-numpy-213-cache:${{ github.ref_name }}
+ cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-numpy-213-cache:${{ github.ref_name }},mode=max
+ - name: Image digest (PySpark with Numpy 2.1.3)
+ if: hashFiles('dev/spark-test-image/numpy-213/Dockerfile') != ''
+ run: echo ${{ steps.docker_build_pyspark_numpy_213.outputs.digest }}
diff --git a/.github/workflows/build_maven_java21_arm.yml b/.github/workflows/build_maven_java21_arm.yml
new file mode 100644
index 0000000000000..505bdd63189c0
--- /dev/null
+++ b/.github/workflows/build_maven_java21_arm.yml
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build / Maven (master, Scala 2.13, Hadoop 3, JDK 21, ARM)"
+
+on:
+ schedule:
+ - cron: '0 15 * * *'
+ workflow_dispatch:
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/maven_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 21
+ os: ubuntu-24.04-arm
+ arch: arm64
diff --git a/.github/workflows/build_maven_java21_macos15.yml b/.github/workflows/build_maven_java21_macos15.yml
index 377a67191ab49..14db1b1871bc4 100644
--- a/.github/workflows/build_maven_java21_macos15.yml
+++ b/.github/workflows/build_maven_java21_macos15.yml
@@ -34,7 +34,11 @@ jobs:
with:
java: 21
os: macos-15
+ arch: arm64
envs: >-
{
- "OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES"
+ "SPARK_TEST_SQL_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD": "256",
+ "SPARK_TEST_SQL_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD": "256",
+ "SPARK_TEST_HIVE_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD": "48",
+ "SPARK_TEST_HIVE_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD": "48"
}
diff --git a/.github/workflows/build_python_3.11_arm.yml b/.github/workflows/build_python_3.11_arm.yml
new file mode 100644
index 0000000000000..f0a1b467703c6
--- /dev/null
+++ b/.github/workflows/build_python_3.11_arm.yml
@@ -0,0 +1,35 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build / Python-only (master, Python 3.11, ARM)"
+
+on:
+ schedule:
+ - cron: '0 22 */3 * *'
+ workflow_dispatch:
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/python_hosted_runner_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ os: ubuntu-24.04-arm
diff --git a/.github/workflows/build_python_3.11_classic_only.yml b/.github/workflows/build_python_3.11_classic_only.yml
new file mode 100644
index 0000000000000..f7f6a24543a2c
--- /dev/null
+++ b/.github/workflows/build_python_3.11_classic_only.yml
@@ -0,0 +1,47 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build / Python-only Classic-only (master, Python 3.11)"
+
+on:
+ schedule:
+ - cron: '0 0 */3 * *'
+ workflow_dispatch:
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 17
+ branch: master
+ hadoop: hadoop3
+ envs: >-
+ {
+ "PYSPARK_IMAGE_TO_TEST": "python-311-classic-only",
+ "PYTHON_TO_TEST": "python3.11"
+ }
+ jobs: >-
+ {
+ "pyspark": "true",
+ "pyspark-pandas": "true"
+ }
diff --git a/.github/workflows/build_python_3.11_macos.yml b/.github/workflows/build_python_3.11_macos.yml
index 57902e4871ffa..9566bfd8271d1 100644
--- a/.github/workflows/build_python_3.11_macos.yml
+++ b/.github/workflows/build_python_3.11_macos.yml
@@ -29,5 +29,5 @@ jobs:
permissions:
packages: write
name: Run
- uses: ./.github/workflows/python_macos_test.yml
+ uses: ./.github/workflows/python_hosted_runner_test.yml
if: github.repository == 'apache/spark'
diff --git a/.github/workflows/build_python_3.13_nogil.yml b/.github/workflows/build_python_3.13_nogil.yml
new file mode 100644
index 0000000000000..6fc717cc118fc
--- /dev/null
+++ b/.github/workflows/build_python_3.13_nogil.yml
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build / Python-only (master, Python 3.13 no GIL)"
+
+on:
+ schedule:
+ - cron: '0 19 */3 * *'
+ workflow_dispatch:
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 17
+ branch: master
+ hadoop: hadoop3
+ envs: >-
+ {
+ "PYSPARK_IMAGE_TO_TEST": "python-313-nogil",
+ "PYTHON_TO_TEST": "python3.13t",
+ "PYTHON_GIL": "0"
+ }
+ jobs: >-
+ {
+ "pyspark": "true",
+ "pyspark-pandas": "true"
+ }
diff --git a/.github/workflows/build_python_connect35.yml b/.github/workflows/build_python_connect35.yml
index 66bd816d39bed..9a5723840dc36 100644
--- a/.github/workflows/build_python_connect35.yml
+++ b/.github/workflows/build_python_connect35.yml
@@ -90,7 +90,8 @@ jobs:
# Start a Spark Connect server for local
PYTHONPATH="python/lib/pyspark.zip:python/lib/py4j-0.10.9.9-src.zip:$PYTHONPATH" ./sbin/start-connect-server.sh \
--driver-java-options "-Dlog4j.configurationFile=file:$GITHUB_WORKSPACE/conf/log4j2.properties" \
- --jars "`find connector/protobuf/target -name spark-protobuf-*SNAPSHOT.jar`,`find connector/avro/target -name spark-avro*SNAPSHOT.jar`"
+ --jars "`find connector/protobuf/target -name spark-protobuf-*SNAPSHOT.jar`,`find connector/avro/target -name spark-avro*SNAPSHOT.jar`" \
+ --conf spark.sql.execution.arrow.pyspark.validateSchema.enabled=false
# Checkout to branch-3.5 to use the tests in branch-3.5.
cd ..
diff --git a/.github/workflows/build_python_numpy_2.1.3.yml b/.github/workflows/build_python_numpy_2.1.3.yml
new file mode 100644
index 0000000000000..345b97c282a02
--- /dev/null
+++ b/.github/workflows/build_python_numpy_2.1.3.yml
@@ -0,0 +1,47 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build / Python-only (master, Python 3.11, Numpy 2.1.3)"
+
+on:
+ schedule:
+ - cron: '0 3 */3 * *'
+ workflow_dispatch:
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 17
+ branch: master
+ hadoop: hadoop3
+ envs: >-
+ {
+ "PYSPARK_IMAGE_TO_TEST": "numpy-213",
+ "PYTHON_TO_TEST": "python3.11"
+ }
+ jobs: >-
+ {
+ "pyspark": "true",
+ "pyspark-pandas": "true"
+ }
diff --git a/.github/workflows/build_sparkr_window.yml b/.github/workflows/build_sparkr_window.yml
index 20362da061a70..e3ef9d7ba0752 100644
--- a/.github/workflows/build_sparkr_window.yml
+++ b/.github/workflows/build_sparkr_window.yml
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
-name: "Build / SparkR-only (master, 4.4.2, windows-2022)"
+name: "Build / SparkR-only (master, 4.4.3, windows-2022)"
on:
schedule:
@@ -51,10 +51,10 @@ jobs:
with:
distribution: zulu
java-version: 17
- - name: Install R 4.4.2
+ - name: Install R 4.4.3
uses: r-lib/actions/setup-r@v2
with:
- r-version: 4.4.2
+ r-version: 4.4.3
- name: Install R dependencies
run: |
Rscript -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival', 'arrow', 'xml2'), repos='https://cloud.r-project.org/')"
diff --git a/.github/workflows/build_uds.yml b/.github/workflows/build_uds.yml
new file mode 100644
index 0000000000000..29aadcecf6d90
--- /dev/null
+++ b/.github/workflows/build_uds.yml
@@ -0,0 +1,53 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build / Unix Domain Socket (master, Hadoop 3, JDK 17, Scala 2.13)"
+
+on:
+ schedule:
+ - cron: '0 1 */3 * *'
+ workflow_dispatch:
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 17
+ branch: master
+ hadoop: hadoop3
+ envs: >-
+ {
+ "PYSPARK_IMAGE_TO_TEST": "python-311",
+ "PYTHON_TO_TEST": "python3.11",
+          "PYSPARK_UDS_MODE": "true"
+ }
+ jobs: >-
+ {
+ "build": "true",
+ "docs": "true",
+ "pyspark": "true",
+ "sparkr": "true",
+ "tpcds-1g": "true",
+ "docker-integration-tests": "true",
+ "yarn": "true"
+ }
diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml
index 7fdfc1c6866ca..857a6d8ece06c 100644
--- a/.github/workflows/maven_test.yml
+++ b/.github/workflows/maven_test.yml
@@ -41,6 +41,11 @@ on:
required: false
type: string
default: ubuntu-latest
+ arch:
+ description: The target architecture (x86, x64, arm64) of the Python or PyPy interpreter.
+ required: false
+ type: string
+ default: x64
envs:
description: Additional environment variables to set when running the tests. Should be in JSON format.
required: false
@@ -169,12 +174,10 @@ jobs:
# We should install one Python that is higher than 3+ for SQL and Yarn because:
# - SQL component also has Python related tests, for example, IntegratedUDFTestUtils.
# - Yarn has a Python specific test too, for example, YarnClusterSuite.
- # macos (14) already has its Python installed, see also SPARK-47096 and
- # https://github.com/actions/runner-images/blob/main/images/macos/macos-14-Readme.md
- if: contains(inputs.os, 'ubuntu') && (contains(matrix.modules, 'resource-managers#yarn') || (contains(matrix.modules, 'sql#core')) || contains(matrix.modules, 'connect'))
+ if: contains(matrix.modules, 'resource-managers#yarn') || (contains(matrix.modules, 'sql#core')) || contains(matrix.modules, 'connect')
with:
python-version: '3.11'
- architecture: x64
+ architecture: ${{ inputs.arch }}
- name: Install Python packages (Python 3.11)
if: contains(matrix.modules, 'resource-managers#yarn') || (contains(matrix.modules, 'sql#core')) || contains(matrix.modules, 'connect')
run: |
@@ -191,6 +194,12 @@ jobs:
# Replace with the real module name, for example, connector#kafka-0-10 -> connector/kafka-0-10
export TEST_MODULES=`echo "$MODULES_TO_TEST" | sed -e "s%#%/%g"`
./build/mvn $MAVEN_CLI_OPTS -DskipTests -Pyarn -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Pjvm-profiler -Pspark-ganglia-lgpl -Pkinesis-asl -Djava.version=${JAVA_VERSION/-ea} clean install
+
+ if [ "$MODULES_TO_TEST" != "connect" ]; then
+ echo "Clean up the assembly module before maven testing"
+ ./build/mvn $MAVEN_CLI_OPTS clean -pl assembly
+ fi
+
if [[ "$INCLUDED_TAGS" != "" ]]; then
./build/mvn $MAVEN_CLI_OPTS -pl "$TEST_MODULES" -Pyarn -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Pjvm-profiler -Pspark-ganglia-lgpl -Pkinesis-asl -Djava.version=${JAVA_VERSION/-ea} -Dtest.include.tags="$INCLUDED_TAGS" test -fae
elif [[ "$MODULES_TO_TEST" == "connect" ]]; then
diff --git a/.github/workflows/python_macos_test.yml b/.github/workflows/python_hosted_runner_test.yml
similarity index 87%
rename from .github/workflows/python_macos_test.yml
rename to .github/workflows/python_hosted_runner_test.yml
index 2cffb68419e8a..03d423acca715 100644
--- a/.github/workflows/python_macos_test.yml
+++ b/.github/workflows/python_hosted_runner_test.yml
@@ -40,6 +40,16 @@ on:
required: false
type: string
default: hadoop3
+ os:
+ description: OS to run this build.
+ required: false
+ type: string
+ default: macos-15
+ arch:
+ description: The target architecture (x86, x64, arm64) of the Python or PyPy interpreter.
+ required: false
+ type: string
+ default: arm64
envs:
description: Additional environment variables to set when running the tests. Should be in JSON format.
required: false
@@ -48,7 +58,7 @@ on:
jobs:
build:
name: "PySpark test on macos: ${{ matrix.modules }}"
- runs-on: macos-15
+ runs-on: ${{ inputs.os }}
strategy:
fail-fast: false
matrix:
@@ -129,11 +139,16 @@ jobs:
with:
distribution: zulu
java-version: ${{ matrix.java }}
+ - name: Install Python ${{matrix.python}}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{matrix.python}}
+ architecture: ${{ inputs.arch }}
- name: Install Python packages (Python ${{matrix.python}})
run: |
python${{matrix.python}} -m pip install --ignore-installed 'blinker>=1.6.2'
python${{matrix.python}} -m pip install --ignore-installed 'six==1.16.0'
- python${{matrix.python}} -m pip install numpy 'pyarrow>=15.0.0' 'six==1.16.0' 'pandas==2.2.3' scipy 'plotly<6.0.0' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' unittest-xml-reporting && \
+ python${{matrix.python}} -m pip install numpy 'pyarrow>=19.0.0' 'six==1.16.0' 'pandas==2.2.3' scipy 'plotly<6.0.0' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' unittest-xml-reporting && \
python${{matrix.python}} -m pip install 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.29.1' 'googleapis-common-protos==1.65.0' 'graphviz==0.20.3' && \
python${{matrix.python}} -m pip cache purge
- name: List Python packages
@@ -152,12 +167,12 @@ jobs:
if: always()
uses: actions/upload-artifact@v4
with:
- name: test-results-${{ matrix.modules }}--${{ matrix.java }}-${{ inputs.hadoop }}-hive2.3-${{ env.PYTHON_TO_TEST }}
+ name: test-results-${{ inputs.os }}-${{ matrix.modules }}--${{ matrix.java }}-${{ inputs.hadoop }}-hive2.3-${{ env.PYTHON_TO_TEST }}
path: "**/target/test-reports/*.xml"
- name: Upload unit tests log files
env: ${{ fromJSON(inputs.envs) }}
if: ${{ !success() }}
uses: actions/upload-artifact@v4
with:
- name: unit-tests-log-${{ matrix.modules }}--${{ matrix.java }}-${{ inputs.hadoop }}-hive2.3-${{ env.PYTHON_TO_TEST }}
+ name: unit-tests-log-${{ inputs.os }}-${{ matrix.modules }}--${{ matrix.java }}-${{ inputs.hadoop }}-hive2.3-${{ env.PYTHON_TO_TEST }}
path: "**/target/unit-tests.log"
diff --git a/.gitignore b/.gitignore
index 0a4138ec26948..d7632abb49f0b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -122,5 +122,6 @@ node_modules
# For Antlr
sql/api/gen/
+sql/api/src/main/gen/
sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens
sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/gen/
diff --git a/LICENSE-binary b/LICENSE-binary
index c8bd77e7ae2ec..0c3c7aecb71ac 100644
--- a/LICENSE-binary
+++ b/LICENSE-binary
@@ -215,8 +215,10 @@ com.github.joshelser:dropwizard-metrics-hadoop-metrics2-reporter
com.google.code.findbugs:jsr305
com.google.code.gson:gson
com.google.crypto.tink:tink
+com.google.errorprone:error_prone_annotations
com.google.flatbuffers:flatbuffers-java
com.google.guava:guava
+com.google.j2objc:j2objc-annotations
com.jamesmurty.utils:java-xmlbuilder
com.ning:compress-lzf
com.squareup.okhttp3:logging-interceptor
@@ -226,7 +228,7 @@ com.tdunning:json
com.twitter:chill-java
com.twitter:chill_2.13
com.univocity:univocity-parsers
-com.zaxxer.HikariCP
+com.zaxxer:HikariCP
commons-cli:commons-cli
commons-codec:commons-codec
commons-collections:commons-collections
@@ -273,6 +275,7 @@ io.jsonwebtoken:jjwt-jackson
io.netty:netty-all
io.netty:netty-buffer
io.netty:netty-codec
+io.netty:netty-codec-dns
io.netty:netty-codec-http
io.netty:netty-codec-http2
io.netty:netty-codec-socks
@@ -280,6 +283,7 @@ io.netty:netty-common
io.netty:netty-handler
io.netty:netty-handler-proxy
io.netty:netty-resolver
+io.netty:netty-resolver-dns
io.netty:netty-tcnative-boringssl-static
io.netty:netty-tcnative-classes
io.netty:netty-transport
@@ -328,7 +332,6 @@ org.apache.hive:hive-cli
org.apache.hive:hive-common
org.apache.hive:hive-exec
org.apache.hive:hive-jdbc
-org.apache.hive:hive-llap-common
org.apache.hive:hive-metastore
org.apache.hive:hive-serde
org.apache.hive:hive-service-rpc
@@ -384,6 +387,8 @@ org.glassfish.jersey.core:jersey-client
org.glassfish.jersey.core:jersey-common
org.glassfish.jersey.core:jersey-server
org.glassfish.jersey.inject:jersey-hk2
+org.javassist:javassist
+org.jetbrains:annotations
org.json4s:json4s-ast_2.13
org.json4s:json4s-core_2.13
org.json4s:json4s-jackson-core_2.13
@@ -440,6 +445,8 @@ jline:jline
org.jodd:jodd-core
pl.edu.icm:JLargeArrays
+python/pyspark/errors/exceptions/tblib.py
+
BSD 3-Clause
------------
diff --git a/NOTICE-binary b/NOTICE-binary
index 3f36596b9d6d6..a3f302b1cb04d 100644
--- a/NOTICE-binary
+++ b/NOTICE-binary
@@ -592,9 +592,6 @@ Copyright 2015 The Apache Software Foundation
Apache Extras Companion for log4j 1.2.
Copyright 2007 The Apache Software Foundation
-Hive Metastore
-Copyright 2016 The Apache Software Foundation
-
Apache Commons Logging
Copyright 2003-2013 The Apache Software Foundation
@@ -969,12 +966,6 @@ The Derby build relies on a jar file supplied by the JSON Simple
project, hosted at https://code.google.com/p/json-simple/.
The JSON simple jar file is licensed under the Apache 2.0 License.
-Hive CLI
-Copyright 2016 The Apache Software Foundation
-
-Hive JDBC
-Copyright 2016 The Apache Software Foundation
-
Chill is a set of Scala extensions for Kryo.
Copyright 2012 Twitter, Inc.
@@ -1056,9 +1047,6 @@ Copyright 2019 The Apache Software Foundation
Hive Query Language
Copyright 2019 The Apache Software Foundation
-Hive Llap Common
-Copyright 2019 The Apache Software Foundation
-
Hive Metastore
Copyright 2019 The Apache Software Foundation
@@ -1083,8 +1071,6 @@ Copyright 2019 The Apache Software Foundation
Hive Storage API
Copyright 2018 The Apache Software Foundation
-Hive Vector-Code-Gen Utilities
-Copyright 2019 The Apache Software Foundation
Apache License
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index eea83aa5ab527..0242e71149785 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -181,7 +181,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
parallelism <- as.integer(numSlices)
jserver <- newJObject("org.apache.spark.api.r.RParallelizeServer", sc, parallelism)
authSecret <- callJMethod(jserver, "secret")
- port <- callJMethod(jserver, "port")
+ port <- callJMethod(jserver, "connInfo")
conn <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
doServerAuth(conn, authSecret)
diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh
index 861ebbbcb7a33..4ab35ad28751e 100644
--- a/bin/load-spark-env.sh
+++ b/bin/load-spark-env.sh
@@ -65,6 +65,6 @@ export SPARK_SCALA_VERSION=2.13
#fi
# Append jline option to enable the Beeline process to run in background.
-if [[ ( ! $(ps -o stat= -p $$) =~ "+" ) && ! ( -p /dev/stdin ) ]]; then
+if [ -e /usr/bin/tty -a "`tty`" != "not a tty" -a ! -p /dev/stdin ]; then
export SPARK_BEELINE_OPTS="$SPARK_BEELINE_OPTS -Djline.terminal=jline.UnsupportedTerminal"
fi
diff --git a/bin/spark-sql b/bin/spark-sql
index b08b944ebd319..6b898f2913897 100755
--- a/bin/spark-sql
+++ b/bin/spark-sql
@@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+export SPARK_CONNECT_MODE=0
if [ -z "${SPARK_HOME}" ]; then
source "$(dirname "$0")"/find-spark-home
diff --git a/bin/spark-sql2.cmd b/bin/spark-sql2.cmd
index c34a3c5aa0739..0dc6edb1a1c4a 100644
--- a/bin/spark-sql2.cmd
+++ b/bin/spark-sql2.cmd
@@ -18,6 +18,8 @@ rem limitations under the License.
rem
rem Figure out where the Spark framework is installed
+set SPARK_CONNECT_MODE=0
+
call "%~dp0find-spark-home.cmd"
set _SPARK_CMD_USAGE=Usage: .\bin\spark-sql [options] [cli option]
diff --git a/bin/sparkR b/bin/sparkR
index 8ecc755839fe3..a99b1dd287a19 100755
--- a/bin/sparkR
+++ b/bin/sparkR
@@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+export SPARK_CONNECT_MODE=0
if [ -z "${SPARK_HOME}" ]; then
source "$(dirname "$0")"/find-spark-home
diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd
index 446f0c30bfe82..a047f756a0bfc 100644
--- a/bin/sparkR2.cmd
+++ b/bin/sparkR2.cmd
@@ -18,6 +18,8 @@ rem limitations under the License.
rem
rem Figure out where the Spark framework is installed
+set SPARK_CONNECT_MODE=0
+
call "%~dp0find-spark-home.cmd"
call "%SPARK_HOME%\bin\load-spark-env.cmd"
diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml
index ec248b7b4602c..896a4192cffff 100644
--- a/common/unsafe/pom.xml
+++ b/common/unsafe/pom.xml
@@ -80,6 +80,14 @@
      <groupId>com.twitter</groupId>
      <artifactId>chill_${scala.binary.version}</artifactId>
    </dependency>
+    <dependency>
+      <groupId>com.esotericsoftware</groupId>
+      <artifactId>kryo-shaded</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.objenesis</groupId>
+      <artifactId>objenesis</artifactId>
+    </dependency>
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 691e40a75dcd9..8962cc3821f36 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -11,6 +11,12 @@
],
"sqlState" : "42845"
},
+ "AGGREGATE_OUT_OF_MEMORY" : {
+ "message" : [
+      "Not enough memory for aggregation"
+ ],
+ "sqlState" : "82001"
+ },
"ALL_PARAMETERS_MUST_BE_NAMED" : {
"message" : [
"Using name parameterized queries requires all parameters to be named. Parameters missing names: ."
@@ -82,6 +88,12 @@
],
"sqlState" : "22003"
},
+ "ARROW_TYPE_MISMATCH" : {
+ "message" : [
+ "Invalid schema from : expected , got ."
+ ],
+ "sqlState" : "42K0G"
+ },
"ARTIFACT_ALREADY_EXISTS" : {
"message" : [
"The artifact already exists. Please choose a different name for the new artifact because it cannot be overwritten."
@@ -292,6 +304,11 @@
"Error reading streaming state file of does not exist. If the stream job is restarted with a new or updated state operation, please create a new checkpoint location or clear the existing checkpoint location."
]
},
+ "FAILED_TO_GET_CHANGELOG_WRITER" : {
+ "message" : [
+ "Failed to get the changelog writer for state store at version ."
+ ]
+ },
"HDFS_STORE_PROVIDER_OUT_OF_MEMORY" : {
"message" : [
"Could not load HDFS state store with id because of an out of memory exception."
@@ -388,6 +405,12 @@
],
"sqlState" : "22018"
},
+ "CANNOT_PARSE_TIME" : {
+ "message" : [
+      "The input string cannot be parsed to a TIME value because it does not match the datetime format ."
+ ],
+ "sqlState" : "22010"
+ },
"CANNOT_PARSE_TIMESTAMP" : {
"message" : [
". Use to tolerate invalid input string and return NULL instead."
@@ -788,6 +811,20 @@
},
"sqlState" : "XX000"
},
+ "CONSTRAINT_ALREADY_EXISTS" : {
+ "message" : [
+ "Constraint '' already exists. Please delete the existing constraint first.",
+ "Existing constraint:",
+ ""
+ ],
+ "sqlState" : "42710"
+ },
+ "CONSTRAINT_DOES_NOT_EXIST" : {
+ "message" : [
+ "Cannot drop nonexistent constraint from table ."
+ ],
+ "sqlState" : "42704"
+ },
"CONVERSION_INVALID_INPUT" : {
"message" : [
"The value () cannot be converted to because it is malformed. Correct the value as per the syntax, or change its format. Use to tolerate malformed input and return NULL instead."
@@ -971,6 +1008,11 @@
"Input schema can only contain STRING as a key type for a MAP."
]
},
+ "INVALID_XML_SCHEMA" : {
+ "message" : [
+ "Input schema must be a struct or a variant."
+ ]
+ },
"IN_SUBQUERY_DATA_TYPE_MISMATCH" : {
"message" : [
"The data type of one or more elements in the left hand side of an IN subquery is not compatible with the data type of the output of the subquery. Mismatched columns: [], left side: [], right side: []."
@@ -1167,8 +1209,20 @@
},
"DATETIME_FIELD_OUT_OF_BOUNDS" : {
"message" : [
- ". If necessary set to \"false\" to bypass this error."
+ "."
],
+ "subClass" : {
+ "WITHOUT_SUGGESTION" : {
+ "message" : [
+ ""
+ ]
+ },
+ "WITH_SUGGESTION" : {
+ "message" : [
+ "If necessary set to \"false\" to bypass this error."
+ ]
+ }
+ },
"sqlState" : "22023"
},
"DATETIME_OVERFLOW" : {
@@ -2328,6 +2382,12 @@
},
"sqlState" : "22022"
},
+ "INVALID_CONSTRAINT_CHARACTERISTICS" : {
+ "message" : [
+ "Constraint characteristics [] are duplicated or conflict with each other."
+ ],
+ "sqlState" : "42613"
+ },
"INVALID_CORRUPT_RECORD_TYPE" : {
"message" : [
"The column for corrupt records must have the nullable STRING type, but got ."
@@ -2381,6 +2441,11 @@
"message" : [
"Cannot detect a seconds fraction pattern of variable length. Please make sure the pattern contains 'S', and does not contain illegal characters."
]
+ },
+ "WITH_SUGGESTION" : {
+ "message" : [
+ "You can form a valid datetime pattern with the guide from '/sql-ref-datetime-pattern.html'."
+ ]
}
},
"sqlState" : "22007"
@@ -2884,6 +2949,12 @@
],
"sqlState" : "F0000"
},
+ "INVALID_KRYO_SERIALIZER_NO_DATA" : {
+ "message" : [
+ "The object '' is invalid or malformed to using ."
+ ],
+ "sqlState" : "22002"
+ },
"INVALID_LABEL_USAGE" : {
"message" : [
"The usage of the label is invalid."
@@ -2926,6 +2997,11 @@
"message" : [
"A higher order function expects arguments, but got ."
]
+ },
+ "PARAMETER_DOES_NOT_ACCEPT_LAMBDA_FUNCTION" : {
+ "message" : [
+        "You passed a lambda function to a parameter that does not accept it. Please check if the lambda function argument is in the correct position."
+ ]
}
},
"sqlState" : "42K0D"
@@ -3297,7 +3373,7 @@
},
"INVALID_SINGLE_VARIANT_COLUMN" : {
"message" : [
- "The `singleVariantColumn` option cannot be used if there is also a user specified schema."
+ "User specified schema is invalid when the `singleVariantColumn` option is enabled. The schema must either be a variant field, or a variant field plus a corrupt column field."
],
"sqlState" : "42613"
},
@@ -3368,6 +3444,16 @@
"ANALYZE TABLE(S) ... COMPUTE STATISTICS ... must be either NOSCAN or empty."
]
},
+ "CREATE_FUNC_WITH_COLUMN_CONSTRAINTS" : {
+ "message" : [
+ "CREATE FUNCTION with constraints on parameters is not allowed."
+ ]
+ },
+ "CREATE_FUNC_WITH_GENERATED_COLUMNS_AS_PARAMETERS" : {
+ "message" : [
+ "CREATE FUNCTION with generated columns as parameters is not allowed."
+ ]
+ },
"CREATE_ROUTINE_WITH_IF_NOT_EXISTS_AND_REPLACE" : {
"message" : [
"Cannot create a routine with both IF NOT EXISTS and REPLACE specified."
@@ -3820,6 +3906,25 @@
},
"sqlState" : "22023"
},
+ "MALFORMED_STATE_IN_RATE_PER_MICRO_BATCH_SOURCE" : {
+ "message" : [
+ "Malformed state in RatePerMicroBatch source."
+ ],
+ "subClass" : {
+ "INVALID_OFFSET" : {
+ "message" : [
+        "The offset value is invalid: startOffset should be less than or equal to the endOffset, but startOffset() > endOffset()."
+ ]
+ },
+ "INVALID_TIMESTAMP" : {
+ "message" : [
+        "The timestamp value is invalid: startTimestamp should be less than or equal to the endTimestamp, but startTimestamp() > endTimestamp().",
+        "This could happen when the streaming query is restarted with a newer `startingTimestamp` and reprocesses the first batch (i.e. batch 0). Please consider using a new checkpoint location."
+ ]
+ }
+ },
+ "sqlState" : "22000"
+ },
"MALFORMED_VARIANT" : {
"message" : [
"Variant binary is malformed. Please check the data source is valid."
@@ -3889,6 +3994,12 @@
],
"sqlState" : "42P20"
},
+ "MULTIPLE_PRIMARY_KEYS" : {
+ "message" : [
+ "Multiple primary keys are defined: . Please ensure that only one primary key is defined for the table."
+ ],
+ "sqlState" : "42P16"
+ },
"MULTIPLE_QUERY_RESULT_CLAUSES_WITH_PIPE_OPERATORS" : {
"message" : [
" and cannot coexist in the same SQL pipe operator using '|>'. Please separate the multiple result clauses into separate pipe operators and then retry the query again."
@@ -4386,6 +4497,12 @@
],
"sqlState" : "XXKD0"
},
+ "POINTER_ARRAY_OUT_OF_MEMORY" : {
+ "message" : [
+ "Not enough memory to grow pointer array"
+ ],
+ "sqlState" : "82002"
+ },
"PROTOBUF_DEPENDENCY_NOT_FOUND" : {
"message" : [
"Could not find dependency: ."
@@ -4447,6 +4564,18 @@
],
"sqlState" : "38000"
},
+ "RECURSION_LEVEL_LIMIT_EXCEEDED" : {
+ "message" : [
+      "Recursion level limit reached but the query has not been exhausted; try increasing 'spark.sql.cteRecursionLevelLimit'"
+ ],
+ "sqlState" : "42836"
+ },
+ "RECURSION_ROW_LIMIT_EXCEEDED" : {
+ "message" : [
+      "Recursion row limit reached but the query has not been exhausted; try increasing 'spark.sql.cteRecursionRowLimit'"
+ ],
+ "sqlState" : "42836"
+ },
"RECURSIVE_CTE_IN_LEGACY_MODE" : {
"message" : [
"Recursive definitions cannot be used in legacy CTE precedence mode (spark.sql.legacy.ctePrecedencePolicy=LEGACY)."
@@ -4639,6 +4768,12 @@
],
"sqlState" : "42601"
},
+ "SPILL_OUT_OF_MEMORY" : {
+ "message" : [
+ "Error while calling spill() on : "
+ ],
+ "sqlState" : "82003"
+ },
"SQL_CONF_NOT_FOUND" : {
"message" : [
"The SQL config cannot be found. Please verify that the config exists."
@@ -4783,6 +4918,12 @@
],
"sqlState" : "42802"
},
+ "STATE_STORE_OPERATION_OUT_OF_ORDER" : {
+ "message" : [
+ "Streaming stateful operator attempted to access state store out of order. This is a bug, please retry. error_msg="
+ ],
+ "sqlState" : "XXKST"
+ },
"STATE_STORE_PROVIDER_DOES_NOT_SUPPORT_FINE_GRAINED_STATE_REPLAY" : {
"message" : [
"The given State Store Provider does not extend org.apache.spark.sql.execution.streaming.state.SupportsFineGrainedReplay.",
@@ -5206,6 +5347,12 @@
],
"sqlState" : "42846"
},
+ "UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE" : {
+ "message" : [
+ "The UNION operator is not yet supported within recursive common table expressions (WITH clauses that refer to themselves, directly or indirectly). Please use UNION ALL instead."
+ ],
+ "sqlState" : "42836"
+ },
"UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT" : {
"message" : [
"Unknown primitive type with id was found in a variant value."
@@ -5457,6 +5604,12 @@
},
"sqlState" : "0A000"
},
+ "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC" : {
+ "message" : [
+ "Constraint characteristic '' is not supported for constraint type ''."
+ ],
+ "sqlState" : "0A000"
+ },
"UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY" : {
"message" : [
"Unsupported data source type for direct query on files: "
@@ -6195,6 +6348,12 @@
],
"sqlState" : "42000"
},
+ "UNSUPPORTED_TIME_PRECISION" : {
+ "message" : [
+ "The seconds precision of the TIME data type is out of the supported range [0, 6]."
+ ],
+ "sqlState" : "0A001"
+ },
"UNSUPPORTED_TYPED_LITERAL" : {
"message" : [
"Literals of the type are not supported. Supported types are ."
@@ -7840,11 +7999,6 @@
"Conflict found: Field differs from derived from ."
]
},
- "_LEGACY_ERROR_TEMP_2130" : {
- "message" : [
- "Fail to recognize '' pattern in the DateTimeFormatter. You can form a valid datetime pattern with the guide from '/sql-ref-datetime-pattern.html'."
- ]
- },
"_LEGACY_ERROR_TEMP_2131" : {
"message" : [
"Exception when registering StreamingQueryListener."
@@ -8487,11 +8641,6 @@
"duration() called on unfinished task"
]
},
- "_LEGACY_ERROR_TEMP_3027" : {
- "message" : [
- "Unrecognized : "
- ]
- },
"_LEGACY_ERROR_TEMP_3028" : {
"message" : [
""
@@ -9315,21 +9464,6 @@
"Doesn't support month or year interval: "
]
},
- "_LEGACY_ERROR_TEMP_3300" : {
- "message" : [
- "error while calling spill() on : "
- ]
- },
- "_LEGACY_ERROR_TEMP_3301" : {
- "message" : [
- "Not enough memory to grow pointer array"
- ]
- },
- "_LEGACY_ERROR_TEMP_3302" : {
- "message" : [
- "No enough memory for aggregation"
- ]
- },
"_LEGACY_ERROR_USER_RAISED_EXCEPTION" : {
"message" : [
""
diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json
index 2e6c83560b014..3cd8163deaf85 100644
--- a/common/utils/src/main/resources/error/error-states.json
+++ b/common/utils/src/main/resources/error/error-states.json
@@ -6504,7 +6504,24 @@
"standard": "N",
"usedBy": ["Oracle"]
},
-
+ "82001": {
+    "description": "Not enough memory for aggregation",
+ "origin": "Spark",
+ "standard": "N",
+ "usedBy": ["Spark"]
+ },
+ "82002": {
+ "description": "Not enough memory to grow pointer array",
+ "origin": "Spark",
+ "standard": "N",
+ "usedBy": ["Spark"]
+ },
+ "82003": {
+ "description": "Error while calling spill()",
+ "origin": "Spark",
+ "standard": "N",
+ "usedBy": ["Spark"]
+ },
"82100": {
"description": "out of memory (could not allocate)",
"origin": "Oracle",
diff --git a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
index e2dd0da1aac85..85d460f618a79 100644
--- a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
+++ b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
@@ -62,7 +62,8 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
}
if (util.SparkEnvUtils.isTesting) {
val placeHoldersNum = ErrorClassesJsonReader.TEMPLATE_REGEX.findAllIn(messageTemplate).length
- if (placeHoldersNum < sanitizedParameters.size) {
+ if (placeHoldersNum < sanitizedParameters.size &&
+ !ErrorClassesJsonReader.MORE_PARAMS_ALLOWLIST.contains(errorClass)) {
throw SparkException.internalError(
s"Found unused message parameters of the error class '$errorClass'. " +
s"Its error message format has $placeHoldersNum placeholders, " +
@@ -123,6 +124,8 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
private object ErrorClassesJsonReader {
private val TEMPLATE_REGEX = "<([a-zA-Z0-9_-]+)>".r
+ private val MORE_PARAMS_ALLOWLIST = Array("CAST_INVALID_INPUT", "CAST_OVERFLOW")
+
private val mapper: JsonMapper = JsonMapper.builder()
.addModule(DefaultScalaModule)
.build()
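
A hedged aside on the allowlist above: under testing, the reader normally fails when a caller supplies more message parameters than the template has placeholders; classes in `MORE_PARAMS_ALLOWLIST` are exempt. A minimal sketch, assuming `errorConditionsUrl` points at the error-conditions.json shown earlier and using illustrative parameter names:

```scala
// Sketch only, not part of this patch; errorConditionsUrl is a hypothetical java.net.URL.
val reader = new ErrorClassesJsonReader(Seq(errorConditionsUrl))

// CAST_INVALID_INPUT is allowlisted, so an extra, unused parameter (here
// "ansiConfig") no longer trips the testing-time internal error.
reader.getErrorMessage(
  "CAST_INVALID_INPUT",
  Map(
    "expression" -> "'a'",
    "sourceType" -> "\"STRING\"",
    "targetType" -> "\"INT\"",
    "ansiConfig" -> "\"spark.sql.ansi.enabled\""))
```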
diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
index 318f32c52b904..1f997592dbfb7 100644
--- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
+++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
@@ -139,6 +139,7 @@ private[spark] object LogKeys {
case object CLUSTER_LABEL extends LogKey
case object CLUSTER_LEVEL extends LogKey
case object CLUSTER_WEIGHT extends LogKey
+ case object CODE extends LogKey
case object CODEC_LEVEL extends LogKey
case object CODEC_NAME extends LogKey
case object CODEGEN_STAGE_ID extends LogKey
@@ -510,6 +511,7 @@ private[spark] object LogKeys {
case object NUM_ITERATIONS extends LogKey
case object NUM_KAFKA_PULLS extends LogKey
case object NUM_KAFKA_RECORDS_PULLED extends LogKey
+ case object NUM_LAGGING_STORES extends LogKey
case object NUM_LEADING_SINGULAR_VALUES extends LogKey
case object NUM_LEFT_PARTITION_VALUES extends LogKey
case object NUM_LOADED_ENTRIES extends LogKey
@@ -704,6 +706,7 @@ private[spark] object LogKeys {
case object RIGHT_EXPR extends LogKey
case object RIGHT_LOGICAL_PLAN_STATS_SIZE_IN_BYTES extends LogKey
case object RMSE extends LogKey
+ case object ROCKS_DB_FILE_MAPPING extends LogKey
case object ROCKS_DB_LOG_LEVEL extends LogKey
case object ROCKS_DB_LOG_MESSAGE extends LogKey
case object RPC_ADDRESS extends LogKey
@@ -749,6 +752,9 @@ private[spark] object LogKeys {
case object SLEEP_TIME extends LogKey
case object SLIDE_DURATION extends LogKey
case object SMALLEST_CLUSTER_INDEX extends LogKey
+ case object SNAPSHOT_EVENT extends LogKey
+ case object SNAPSHOT_EVENT_TIME_DELTA extends LogKey
+ case object SNAPSHOT_EVENT_VERSION_DELTA extends LogKey
case object SNAPSHOT_VERSION extends LogKey
case object SOCKET_ADDRESS extends LogKey
case object SOURCE extends LogKey
diff --git a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala
index 110c5f0934286..cc5d0281829d0 100644
--- a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala
+++ b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala
@@ -26,6 +26,7 @@ import org.apache.logging.log4j.core.appender.ConsoleAppender
import org.apache.logging.log4j.core.config.DefaultConfiguration
import org.apache.logging.log4j.core.filter.AbstractFilter
import org.slf4j.{Logger, LoggerFactory}
+import org.slf4j.event.{Level => Slf4jLevel}
import org.apache.spark.internal.Logging.SparkShellLoggingFilter
import org.apache.spark.internal.LogKeys
@@ -87,7 +88,7 @@ object MDC {
* Wrapper class for log messages that include a logging context.
* This is used as the return type of the string interpolator `LogStringContext`.
*/
-case class MessageWithContext(message: String, context: java.util.HashMap[String, String]) {
+case class MessageWithContext(message: String, context: java.util.Map[String, String]) {
def +(mdc: MessageWithContext): MessageWithContext = {
val resultMap = new java.util.HashMap(context)
resultMap.putAll(mdc.context)
@@ -105,7 +106,7 @@ class LogEntry(messageWithContext: => MessageWithContext) {
def message: String = cachedMessageWithContext.message
- def context: java.util.HashMap[String, String] = cachedMessageWithContext.context
+ def context: java.util.Map[String, String] = cachedMessageWithContext.context
}
/**
@@ -166,7 +167,7 @@ trait Logging {
}
}
- protected def withLogContext(context: java.util.HashMap[String, String])(body: => Unit): Unit = {
+ protected def withLogContext(context: java.util.Map[String, String])(body: => Unit): Unit = {
// put into thread context only when structured logging is enabled
val closeableThreadContextOpt = if (Logging.isStructuredLoggingEnabled) {
Some(CloseableThreadContext.putAll(context))
@@ -307,6 +308,16 @@ trait Logging {
log.isTraceEnabled
}
+ protected def logBasedOnLevel(level: Slf4jLevel)(f: => MessageWithContext): Unit = {
+ level match {
+ case Slf4jLevel.TRACE => logTrace(f.message)
+ case Slf4jLevel.DEBUG => logDebug(f.message)
+ case Slf4jLevel.INFO => logInfo(f)
+ case Slf4jLevel.WARN => logWarning(f)
+ case Slf4jLevel.ERROR => logError(f)
+ }
+ }
+
protected def initializeLogIfNecessary(isInterpreter: Boolean): Unit = {
initializeLogIfNecessary(isInterpreter, silent = false)
}
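
A hedged usage sketch for the new `logBasedOnLevel` helper above: the level is picked at runtime, and for INFO and above the structured-logging message is passed through with its MDC context intact. `laggingStores` is a hypothetical value; `NUM_LAGGING_STORES` is the log key added to LogKey.scala in this patch.

```scala
// Sketch only, assumed to live inside a class that extends Logging.
import org.slf4j.event.{Level => Slf4jLevel}
import org.apache.spark.internal.{LogKeys, MDC}

val laggingStores = 12  // hypothetical measurement
val level = if (laggingStores > 10) Slf4jLevel.WARN else Slf4jLevel.DEBUG
logBasedOnLevel(level) {
  log"Detected ${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores)} lagging state stores"
}
```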
diff --git a/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
index d3e975d1782f0..0f8a6b5fe334d 100644
--- a/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
+++ b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
@@ -17,6 +17,7 @@
package org.apache.spark.internal.config
+import java.util.Locale
import java.util.concurrent.TimeUnit
import java.util.regex.PatternSyntaxException
@@ -46,6 +47,25 @@ private object ConfigHelpers {
}
}
+ def toEnum[E <: Enumeration](s: String, enumClass: E, key: String): enumClass.Value = {
+ try {
+ enumClass.withName(s.trim.toUpperCase(Locale.ROOT))
+ } catch {
+ case _: NoSuchElementException =>
+ throw new IllegalArgumentException(
+ s"$key should be one of ${enumClass.values.mkString(", ")}, but was $s")
+ }
+ }
+
+ def toEnum[E <: Enum[E]](s: String, enumClass: Class[E], key: String): E = {
+ enumClass.getEnumConstants.find(_.name().equalsIgnoreCase(s.trim)) match {
+ case Some(enum) => enum
+ case None =>
+ throw new IllegalArgumentException(
+ s"$key should be one of ${enumClass.getEnumConstants.mkString(", ")}, but was $s")
+ }
+ }
+
def stringToSeq[T](str: String, converter: String => T): Seq[T] = {
SparkStringUtils.stringToSeq(str).map(converter)
}
@@ -271,6 +291,16 @@ private[spark] case class ConfigBuilder(key: String) {
new TypedConfigBuilder(this, v => v)
}
+ def enumConf(e: Enumeration): TypedConfigBuilder[e.Value] = {
+ checkPrependConfig
+ new TypedConfigBuilder(this, toEnum(_, e, key))
+ }
+
+ def enumConf[E <: Enum[E]](e: Class[E]): TypedConfigBuilder[E] = {
+ checkPrependConfig
+ new TypedConfigBuilder(this, toEnum(_, e, key))
+ }
+
def timeConf(unit: TimeUnit): TypedConfigBuilder[Long] = {
checkPrependConfig
new TypedConfigBuilder(this, timeFromString(_, unit), timeToString(_, unit))
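
A hedged sketch of how the two new `enumConf` overloads might be used; the config keys and enum are made up for illustration, and the snippet is assumed to sit under the org.apache.spark package since `ConfigBuilder` is `private[spark]`.

```scala
// Sketch only, not part of this patch.
import java.util.concurrent.TimeUnit
import org.apache.spark.internal.config.ConfigBuilder

object ExampleConfigs {
  object CompressionMode extends Enumeration {
    val FAST, BALANCED, BEST = Value
  }

  // Scala Enumeration flavour: values are matched case-insensitively; an invalid
  // string yields "<key> should be one of FAST, BALANCED, BEST, but was ...".
  val COMPRESSION_MODE = ConfigBuilder("spark.example.compression.mode")
    .enumConf(CompressionMode)
    .createWithDefault(CompressionMode.BALANCED)

  // Java enum flavour goes through the Class[E] overload.
  val TICK_UNIT = ConfigBuilder("spark.example.tick.unit")
    .enumConf(classOf[TimeUnit])
    .createWithDefault(TimeUnit.SECONDS)
}
```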
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
index 4205b76a530aa..d2cb622b4616e 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
@@ -216,7 +216,7 @@ static long readLong(byte[] bytes, int pos, int numBytes) {
// Read a little-endian unsigned int value from `bytes[pos, pos + numBytes)`. The value must fit
// into a non-negative int (`[0, Integer.MAX_VALUE]`).
- static int readUnsigned(byte[] bytes, int pos, int numBytes) {
+ public static int readUnsigned(byte[] bytes, int pos, int numBytes) {
checkIndex(pos, bytes.length);
checkIndex(pos + numBytes - 1, bytes.length);
int result = 0;
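
With `readUnsigned` now public, callers outside the class can decode the little-endian unsigned ints used in the variant binary layout. A small sketch with arbitrary bytes:

```scala
// Sketch only: decode a 3-byte little-endian unsigned int (0x00012A = 298).
import org.apache.spark.types.variant.VariantUtil

val bytes = Array[Byte](0x2A, 0x01, 0x00)
assert(VariantUtil.readUnsigned(bytes, 0, 3) == 298)
```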
diff --git a/connect-examples/server-library-example/README.md b/connect-examples/server-library-example/README.md
index 6028a66cd5c7b..adf4830d58ff5 100644
--- a/connect-examples/server-library-example/README.md
+++ b/connect-examples/server-library-example/README.md
@@ -85,7 +85,7 @@ reading, writing and processing data in the custom format. The plugins (`CustomC
4. **Copy relevant JARs to the root of the unpacked Spark distribution**:
```bash
cp \
- /connect-examples/server-library-example/resources/spark-daria_2.13-1.2.3.jar \
+ /connect-examples/server-library-example/common/target/spark-daria_2.13-1.2.3.jar \
/connect-examples/server-library-example/common/target/spark-server-library-example-common-1.0.0.jar \
/connect-examples/server-library-example/server/target/spark-server-library-example-server-extension-1.0.0.jar \
.
diff --git a/connect-examples/server-library-example/resources/spark-daria_2.13-1.2.3.jar b/connect-examples/server-library-example/resources/spark-daria_2.13-1.2.3.jar
deleted file mode 100644
index 31703de77709d..0000000000000
Binary files a/connect-examples/server-library-example/resources/spark-daria_2.13-1.2.3.jar and /dev/null differ
diff --git a/connect-examples/server-library-example/server/pom.xml b/connect-examples/server-library-example/server/pom.xml
index b95c5e3d61615..b13a7537f9c13 100644
--- a/connect-examples/server-library-example/server/pom.xml
+++ b/connect-examples/server-library-example/server/pom.xml
@@ -62,9 +62,8 @@
      <groupId>com.github.mrpowers</groupId>
-      <artifactId>spark-daria_2.12</artifactId>
+      <artifactId>spark-daria_${scala.binary}</artifactId>
      <version>1.2.3</version>
-      <scope>provided</scope>
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 4ddf6503d99ec..6f345e069ff78 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -21,6 +21,7 @@ import java.io._
import java.net.URI
import java.nio.file.{Files, Paths, StandardCopyOption}
import java.sql.{Date, Timestamp}
+import java.time.{LocalDate, LocalDateTime}
import java.util.UUID
import scala.jdk.CollectionConverters._
@@ -963,6 +964,36 @@ abstract class AvroSuite
}
}
+ test("SPARK-49082: Widening date to timestampNTZ in AvroDeserializer") {
+ withTempPath { tempPath =>
+ // Since timestampNTZ only supports timestamps from
+ // -290308-12-21 BCE 19:59:06 to +294247-01-10 CE 04:00:54,
+ // dates outside of this range cannot be widened to timestampNTZ
+ // and will throw an ArithmeticException.
+ val datePath = s"$tempPath/date_data"
+ val dateDf =
+ Seq(LocalDate.of(2024, 1, 1),
+ LocalDate.of(2024, 1, 2),
+ LocalDate.of(1312, 2, 27),
+ LocalDate.of(0, 1, 1),
+ LocalDate.of(-1, 12, 31),
+ LocalDate.of(-290308, 12, 22), // minimum timestampNTZ date
+ LocalDate.of(294247, 1, 10)) // maximum timestampNTZ date
+ .toDF("col")
+ dateDf.write.format("avro").save(datePath)
+ checkAnswer(
+ spark.read.schema("col TIMESTAMP_NTZ").format("avro").load(datePath),
+ Seq(Row(LocalDateTime.of(2024, 1, 1, 0, 0)),
+ Row(LocalDateTime.of(2024, 1, 2, 0, 0)),
+ Row(LocalDateTime.of(1312, 2, 27, 0, 0)),
+ Row(LocalDateTime.of(0, 1, 1, 0, 0)),
+ Row(LocalDateTime.of(-1, 12, 31, 0, 0)),
+ Row(LocalDateTime.of(-290308, 12, 22, 0, 0)),
+ Row(LocalDateTime.of(294247, 1, 10, 0, 0)))
+ )
+ }
+ }
+
test("SPARK-43380: Fix Avro data type conversion" +
" of DayTimeIntervalType to avoid producing incorrect results") {
withTempPath { path =>
@@ -3086,6 +3117,22 @@ abstract class AvroSuite
}
}
}
+
+ test("SPARK-51590: unsupported the TIME data types in Avro") {
+ withTempDir { dir =>
+ val tempDir = new File(dir, "files").getCanonicalPath
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("select time'12:01:02' as t")
+ .write.format("avro").mode("overwrite").save(tempDir)
+ },
+ condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE",
+ parameters = Map(
+ "columnName" -> "`t`",
+ "columnType" -> s"\"${TimeType().sql}\"",
+ "format" -> "Avro"))
+ }
+ }
}
class AvroV1Suite extends AvroSuite {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
index 91a82075a3607..d12cfe7785cef 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
@@ -51,6 +51,11 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
override val namespaceOpt: Option[String] = Some("DB2INST1")
override val db = new DB2DatabaseOnDocker
+ object JdbcClientTypes {
+ val INTEGER = "INTEGER"
+ val DOUBLE = "DOUBLE"
+ }
+
override def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.catalog.db2", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.db2.url", db.getJdbcUrl(dockerIp, externalPort))
@@ -74,12 +79,12 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
sql(s"CREATE TABLE $tbl (ID INTEGER)")
var t = spark.table(tbl)
var expectedSchema = new StructType()
- .add("ID", IntegerType, true, defaultMetadata(IntegerType))
+ .add("ID", IntegerType, true, defaultMetadata(IntegerType, JdbcClientTypes.INTEGER))
assert(t.schema === expectedSchema)
sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE DOUBLE")
t = spark.table(tbl)
expectedSchema = new StructType()
- .add("ID", DoubleType, true, defaultMetadata(DoubleType))
+ .add("ID", DoubleType, true, defaultMetadata(DoubleType, JdbcClientTypes.DOUBLE))
assert(t.schema === expectedSchema)
// Update column type from DOUBLE to STRING
val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE VARCHAR(10)"
@@ -103,7 +108,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
s" TBLPROPERTIES('CCSID'='UNICODE')")
val t = spark.table(tbl)
val expectedSchema = new StructType()
- .add("ID", IntegerType, true, defaultMetadata(IntegerType))
+ .add("ID", IntegerType, true, defaultMetadata(IntegerType, JdbcClientTypes.INTEGER))
assert(t.schema === expectedSchema)
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
index 04637c1b55631..0ff5b9aa56567 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
@@ -41,6 +41,11 @@ import org.apache.spark.tags.DockerTest
@DockerTest
class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
+ object JdbcClientTypes {
+ val INTEGER = "int"
+ val STRING = "nvarchar"
+ }
+
def getExternalEngineQuery(executedPlan: SparkPlan): String = {
getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery
}
@@ -93,18 +98,28 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
).executeUpdate()
}
+ override def testRenameColumn(tbl: String): Unit = {
+ sql(s"ALTER TABLE $tbl RENAME COLUMN ID TO RENAMED")
+ val t = spark.table(s"$tbl")
+ val expectedSchema = new StructType()
+ .add("RENAMED", StringType, true, defaultMetadata(StringType, JdbcClientTypes.STRING))
+ .add("ID1", StringType, true, defaultMetadata(StringType, JdbcClientTypes.STRING))
+ .add("ID2", StringType, true, defaultMetadata(StringType, JdbcClientTypes.STRING))
+ assert(t.schema === expectedSchema)
+ }
+
override def notSupportsTableComment: Boolean = true
override def testUpdateColumnType(tbl: String): Unit = {
sql(s"CREATE TABLE $tbl (ID INTEGER)")
var t = spark.table(tbl)
var expectedSchema = new StructType()
- .add("ID", IntegerType, true, defaultMetadata(IntegerType))
+ .add("ID", IntegerType, true, defaultMetadata(IntegerType, JdbcClientTypes.INTEGER))
assert(t.schema === expectedSchema)
sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE STRING")
t = spark.table(tbl)
expectedSchema = new StructType()
- .add("ID", StringType, true, defaultMetadata())
+ .add("ID", StringType, true, defaultMetadata(StringType, JdbcClientTypes.STRING))
assert(t.schema === expectedSchema)
// Update column type from STRING to INTEGER
val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE INTEGER"
@@ -136,16 +151,19 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
test("SPARK-47440: SQLServer does not support boolean expression in binary comparison") {
val df1 = sql("SELECT name FROM " +
s"$catalogName.employee WHERE ((name LIKE 'am%') = (name LIKE '%y'))")
+ checkFilterPushed(df1)
assert(df1.collect().length == 4)
val df2 = sql("SELECT name FROM " +
s"$catalogName.employee " +
"WHERE ((name NOT LIKE 'am%') = (name NOT LIKE '%y'))")
+ checkFilterPushed(df2)
assert(df2.collect().length == 4)
val df3 = sql("SELECT name FROM " +
s"$catalogName.employee " +
"WHERE (dept > 1 AND ((name LIKE 'am%') = (name LIKE '%y')))")
+ checkFilterPushed(df3)
assert(df3.collect().length == 3)
}
@@ -159,6 +177,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
|SELECT * FROM tbl
|WHERE deptString = 'first'
|""".stripMargin)
+ checkFilterPushed(df)
assert(df.collect().length == 2)
}
@@ -169,6 +188,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
|""".stripMargin
)
+ checkFilterPushed(df)
// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """
@@ -184,6 +204,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
|""".stripMargin
)
+ checkFilterPushed(df)
// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """
@@ -201,6 +222,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
|""".stripMargin
)
+ checkFilterPushed(df)
// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """
@@ -218,6 +240,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
|""".stripMargin
)
+ checkFilterPushed(df)
// scalastyle:off
assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
"""SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """
@@ -225,4 +248,28 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
// scalastyle:on
df.collect()
}
+
+ test("SPARK-51321: SQLServer pushdown for RPAD expression on string column") {
+ val df = sql(
+ s"""|SELECT name FROM $catalogName.employee
+ |WHERE rpad(name, 10, 'x') = 'amyxxxxxxx'
+ |""".stripMargin
+ )
+ checkFilterPushed(df)
+ val rows = df.collect()
+ assert(rows.length == 1)
+ assert(rows(0).getString(0) === "amy")
+ }
+
+ test("SPARK-51321: SQLServer pushdown for LPAD expression on string column") {
+ val df = sql(
+ s"""|SELECT name FROM $catalogName.employee
+ |WHERE lpad(name, 10, 'x') = 'xxxxxxxamy'
+ |""".stripMargin
+ )
+ checkFilterPushed(df)
+ val rows = df.collect()
+ assert(rows.length == 1)
+ assert(rows(0).getString(0) === "amy")
+ }
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
index 4733af882257d..22de131aca4a0 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
@@ -59,6 +59,21 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
override val catalogName: String = "mysql"
override val db = new MySQLDatabaseOnDocker
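+  // MySQL-specific client type names; defaultMetadata is overridden to record them as jdbcClientType.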
+ case class JdbcClientTypes(INTEGER: String, STRING: String)
+
+ val jdbcClientTypes: JdbcClientTypes =
+ JdbcClientTypes(INTEGER = "INT", STRING = "LONGTEXT")
+
+ override def defaultMetadata(
+ dataType: DataType = StringType,
+ jdbcClientType: String = jdbcClientTypes.STRING): Metadata =
+ new MetadataBuilder()
+ .putLong("scale", 0)
+ .putBoolean("isTimestampNTZ", false)
+ .putBoolean("isSigned", dataType.isInstanceOf[NumericType])
+ .putString("jdbcClientType", jdbcClientType)
+ .build()
+
override def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.catalog.mysql", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.mysql.url", db.getJdbcUrl(dockerIp, externalPort))
@@ -90,13 +105,16 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
"('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate()
connection.prepareStatement("INSERT INTO datetime VALUES " +
"('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate()
+ // '2022-01-01' is a Saturday and falls in ISO year 2021.
+ connection.prepareStatement("INSERT INTO datetime VALUES " +
+ "('tom', '2022-01-01', '2022-01-01 00:00:00')").executeUpdate()
}
override def testUpdateColumnType(tbl: String): Unit = {
sql(s"CREATE TABLE $tbl (ID INTEGER)")
var t = spark.table(tbl)
var expectedSchema = new StructType()
- .add("ID", IntegerType, true, defaultMetadata(IntegerType))
+ .add("ID", IntegerType, true, defaultMetadata(IntegerType, jdbcClientTypes.INTEGER))
assert(t.schema === expectedSchema)
sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE STRING")
t = spark.table(tbl)
@@ -150,7 +168,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
s" TBLPROPERTIES('ENGINE'='InnoDB', 'DEFAULT CHARACTER SET'='utf8')")
val t = spark.table(tbl)
val expectedSchema = new StructType()
- .add("ID", IntegerType, true, defaultMetadata(IntegerType))
+ .add("ID", IntegerType, true, defaultMetadata(IntegerType, jdbcClientTypes.INTEGER))
assert(t.schema === expectedSchema)
}
@@ -185,7 +203,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
assert(rows2(0).getString(0) === "amy")
assert(rows2(1).getString(0) === "alex")
- val df3 = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5")
+ val df3 = sql(s"SELECT name FROM $tbl WHERE month(date1) = 5")
checkFilterPushed(df3)
val rows3 = df3.collect()
assert(rows3.length === 2)
@@ -195,17 +213,19 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0")
checkFilterPushed(df4)
val rows4 = df4.collect()
- assert(rows4.length === 2)
+ assert(rows4.length === 3)
assert(rows4(0).getString(0) === "amy")
assert(rows4(1).getString(0) === "alex")
+ assert(rows4(2).getString(0) === "tom")
val df5 = sql(s"SELECT name FROM $tbl WHERE " +
- "extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022")
+ "extract(WEEK from date1) > 10 AND extract(YEAR from date1) = 2022")
checkFilterPushed(df5)
val rows5 = df5.collect()
- assert(rows5.length === 2)
+ assert(rows5.length === 3)
assert(rows5(0).getString(0) === "amy")
assert(rows5(1).getString(0) === "alex")
+ assert(rows5(2).getString(0) === "tom")
val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " +
"AND datediff(date1, '2022-05-10') > 0")
@@ -220,11 +240,44 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
assert(rows7.length === 1)
assert(rows7(0).getString(0) === "alex")
- val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4")
- checkFilterPushed(df8)
- val rows8 = df8.collect()
- assert(rows8.length === 1)
- assert(rows8(0).getString(0) === "alex")
+ withClue("weekofyear") {
+ val woy = sql(s"SELECT weekofyear(date1) FROM $tbl WHERE name = 'tom'")
+ .collect().head.getInt(0)
+ val df = sql(s"SELECT name FROM $tbl WHERE weekofyear(date1) = $woy")
+ checkFilterPushed(df)
+ val rows = df.collect()
+ assert(rows.length === 1)
+ assert(rows(0).getString(0) === "tom")
+ }
+
+ withClue("dayofweek") {
+ val dow = sql(s"SELECT dayofweek(date1) FROM $tbl WHERE name = 'alex'")
+ .collect().head.getInt(0)
+ val df = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = $dow")
+ checkFilterPushed(df)
+ val rows = df.collect()
+ assert(rows.length === 1)
+ assert(rows(0).getString(0) === "alex")
+ }
+
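+    // YEAROFWEEK and SECOND are not pushed down to MySQL, so these cases expect
+    // the filter to stay on the Spark side (checkFilterPushed(df, false)).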
+ withClue("yearofweek") {
+ val yow = sql(s"SELECT extract(YEAROFWEEK from date1) FROM $tbl WHERE name = 'tom'")
+ .collect().head.getInt(0)
+ val df = sql(s"SELECT name FROM $tbl WHERE extract(YEAROFWEEK from date1) = $yow")
+ checkFilterPushed(df, false)
+ val rows = df.collect()
+ assert(rows.length === 1)
+ assert(rows(0).getString(0) === "tom")
+ }
+
+ withClue("second") {
+ val df = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5")
+ checkFilterPushed(df, false)
+ val rows = df.collect()
+ assert(rows.length === 2)
+ assert(rows(0).getString(0) === "amy")
+ assert(rows(1).getString(0) === "alex")
+ }
val df9 = sql(s"SELECT name FROM $tbl WHERE " +
"dayofyear(date1) > 100 order by dayofyear(date1) limit 1")
@@ -253,11 +306,18 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
*/
@DockerTest
class MySQLOverMariaConnectorIntegrationSuite extends MySQLIntegrationSuite {
- override def defaultMetadata(dataType: DataType = StringType): Metadata = new MetadataBuilder()
- .putLong("scale", 0)
- .putBoolean("isTimestampNTZ", false)
- .putBoolean("isSigned", true)
- .build()
+ override val jdbcClientTypes: JdbcClientTypes =
+ JdbcClientTypes(INTEGER = "INTEGER", STRING = "LONGTEXT")
+
+ override def defaultMetadata(
+ dataType: DataType = StringType,
+ jdbcClientType: String = jdbcClientTypes.STRING): Metadata =
+ new MetadataBuilder()
+ .putLong("scale", 0)
+ .putBoolean("isTimestampNTZ", false)
+ .putBoolean("isSigned", true)
+ .putString("jdbcClientType", jdbcClientType)
+ .build()
override val db = new MySQLDatabaseOnDocker {
override def getJdbcUrl(ip: String, port: Int): String =
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
index 2c97a588670a8..6499db2cc03b4 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
@@ -75,12 +75,24 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
override val namespaceOpt: Option[String] = Some("SYSTEM")
override val db = new OracleDatabaseOnDocker
- override def defaultMetadata(dataType: DataType): Metadata = new MetadataBuilder()
- .putLong("scale", 0)
- .putBoolean("isTimestampNTZ", false)
- .putBoolean("isSigned", dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType])
- .putString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY, "varchar(255)")
- .build()
+ object JdbcClientTypes {
+ val NUMBER = "NUMBER"
+ val STRING = "VARCHAR2"
+ }
+
+ override def defaultMetadata(
+ dataType: DataType = StringType,
+ jdbcClientType: String = JdbcClientTypes.STRING): Metadata =
+ new MetadataBuilder()
+ .putLong("scale", 0)
+ .putBoolean("isTimestampNTZ", false)
+ .putBoolean(
+ "isSigned",
+ dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType])
+ .putString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY, "varchar(255)")
+ .putString("jdbcClientType", jdbcClientType)
+ .build()
+
override def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.catalog.oracle", classOf[JDBCTableCatalog].getName)
@@ -105,12 +117,20 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
sql(s"CREATE TABLE $tbl (ID INTEGER)")
var t = spark.table(tbl)
var expectedSchema = new StructType()
- .add("ID", DecimalType(10, 0), true, super.defaultMetadata(DecimalType(10, 0)))
+ .add(
+ "ID",
+ DecimalType(10, 0),
+ true,
+ super.defaultMetadata(DecimalType(10, 0), JdbcClientTypes.NUMBER))
assert(t.schema === expectedSchema)
sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE LONG")
t = spark.table(tbl)
expectedSchema = new StructType()
- .add("ID", DecimalType(19, 0), true, super.defaultMetadata(DecimalType(19, 0)))
+ .add(
+ "ID",
+ DecimalType(19, 0),
+ true,
+ super.defaultMetadata(DecimalType(19, 0), JdbcClientTypes.NUMBER))
assert(t.schema === expectedSchema)
// Update column type from LONG to INTEGER
val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE INTEGER"
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
index af3c17dc98ae8..5211b5d328e98 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
@@ -22,7 +22,6 @@ import java.sql.Connection
import org.apache.spark.{SparkConf, SparkSQLException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
-import org.apache.spark.sql.execution.FilterExec
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.jdbc.PostgresDatabaseOnDocker
import org.apache.spark.sql.types._
@@ -39,6 +38,22 @@ import org.apache.spark.tags.DockerTest
class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
override val catalogName: String = "postgresql"
override val db = new PostgresDatabaseOnDocker
+
+ object JdbcClientTypes {
+ val INTEGER = "int4"
+ val STRING = "text"
+ }
+
+ override def defaultMetadata(
+ dataType: DataType = StringType,
+ jdbcClientType: String = JdbcClientTypes.STRING): Metadata =
+ new MetadataBuilder()
+ .putLong("scale", 0)
+ .putBoolean("isTimestampNTZ", false)
+ .putBoolean("isSigned", dataType.isInstanceOf[NumericType])
+ .putString("jdbcClientType", jdbcClientType)
+ .build()
+
override def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort))
@@ -194,7 +209,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
sql(s"CREATE TABLE $tbl (ID INTEGER)")
var t = spark.table(tbl)
var expectedSchema = new StructType()
- .add("ID", IntegerType, true, defaultMetadata(IntegerType))
+ .add("ID", IntegerType, true, defaultMetadata(IntegerType, JdbcClientTypes.INTEGER))
assert(t.schema === expectedSchema)
sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE STRING")
t = spark.table(tbl)
@@ -223,7 +238,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
s" TBLPROPERTIES('TABLESPACE'='pg_default')")
val t = spark.table(tbl)
val expectedSchema = new StructType()
- .add("ID", IntegerType, true, defaultMetadata(IntegerType))
+ .add("ID", IntegerType, true, defaultMetadata(IntegerType, JdbcClientTypes.INTEGER))
assert(t.schema === expectedSchema)
}
@@ -250,7 +265,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
test("SPARK-49695: Postgres fix xor push-down") {
val df = spark.sql(s"select dept, name from $catalogName.employee where dept ^ 6 = 0")
val rows = df.collect()
- assert(!df.queryExecution.sparkPlan.exists(_.isInstanceOf[FilterExec]))
+ checkFilterPushed(df)
assert(rows.length == 1)
assert(rows(0).getInt(0) === 6)
assert(rows(0).getString(1) === "jen")
@@ -374,4 +389,28 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
parameters = Map("pos" -> "0", "type" -> "\"ARRAY>\"")
)
}
+
+ test("SPARK-51321: Postgres pushdown for RPAD expression on string column") {
+ val df = sql(
+ s"""|SELECT name FROM $catalogName.employee
+ |WHERE rpad(name, 10, 'x') = 'amyxxxxxxx'
+ |""".stripMargin
+ )
+ checkFilterPushed(df)
+ val rows = df.collect()
+ assert(rows.length === 1)
+ assert(rows(0).getString(0) === "amy")
+ }
+
+ test("SPARK-51321: Postgres pushdown for LPAD expression on string column") {
+ val df = sql(
+ s"""|SELECT name FROM $catalogName.employee
+ |WHERE lpad(name, 10, 'x') = 'xxxxxxxamy'
+ |""".stripMargin
+ )
+ checkFilterPushed(df)
+ val rows = df.collect()
+ assert(rows.length === 1)
+ assert(rows(0).getString(0) === "amy")
+ }
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index 47e1c8c06dd45..51862ae1535cf 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -48,23 +48,52 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
def notSupportsTableComment: Boolean = false
- def defaultMetadata(dataType: DataType = StringType): Metadata = new MetadataBuilder()
+ def defaultMetadata(
+ dataType: DataType = StringType,
+ jdbcClientType: String = "STRING"): Metadata = new MetadataBuilder()
.putLong("scale", 0)
.putBoolean("isTimestampNTZ", false)
.putBoolean("isSigned", dataType.isInstanceOf[NumericType])
+ .putString("jdbcClientType", jdbcClientType)
.build()
+ /**
+ * Returns a copy of the given [[StructType]] with the specified metadata key removed
+ * from all of its fields.
+ */
+ def removeMetadataFromAllFields(structType: StructType, metadataKey: String): StructType = {
+ val updatedFields = structType.fields.map { field =>
+ val oldMetadata = field.metadata
+ val newMetadataBuilder = new MetadataBuilder()
+ .withMetadata(oldMetadata)
+ .remove(metadataKey)
+ field.copy(metadata = newMetadataBuilder.build())
+ }
+ StructType(updatedFields)
+ }
+
def testUpdateColumnNullability(tbl: String): Unit = {
sql(s"CREATE TABLE $catalogName.alt_table (ID STRING NOT NULL)")
var t = spark.table(s"$catalogName.alt_table")
// nullable is true in the expectedSchema because Spark always sets nullable to true
// regardless of the JDBC metadata https://github.com/apache/spark/pull/18445
var expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata())
- assert(t.schema === expectedSchema)
+ // If the function is not overridden, we don't want to compare external engine types
+ var expectedSchemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(expectedSchema, "jdbcClientType")
+ var schemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(t.schema, "jdbcClientType")
+ assert(schemaWithoutJdbcClientType === expectedSchemaWithoutJdbcClientType)
sql(s"ALTER TABLE $catalogName.alt_table ALTER COLUMN ID DROP NOT NULL")
t = spark.table(s"$catalogName.alt_table")
expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata())
- assert(t.schema === expectedSchema)
+
+ // If the function is not overridden, we don't want to compare external engine types
+ expectedSchemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(expectedSchema, "jdbcClientType")
+ schemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(t.schema, "jdbcClientType")
+ assert(schemaWithoutJdbcClientType === expectedSchemaWithoutJdbcClientType)
// Update nullability of a non-existing column
val sqlText = s"ALTER TABLE $catalogName.alt_table ALTER COLUMN bad_column DROP NOT NULL"
checkError(
@@ -85,7 +114,13 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
val expectedSchema = new StructType().add("RENAMED", StringType, true, defaultMetadata())
.add("ID1", StringType, true, defaultMetadata())
.add("ID2", StringType, true, defaultMetadata())
- assert(t.schema === expectedSchema)
+
+ // If the function is not overridden, we don't want to compare external engine types
+ val expectedSchemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(expectedSchema, "jdbcClientType")
+ val schemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(t.schema, "jdbcClientType")
+ assert(schemaWithoutJdbcClientType === expectedSchemaWithoutJdbcClientType)
}
def testCreateTableWithProperty(tbl: String): Unit = {}
@@ -109,18 +144,30 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
var t = spark.table(s"$catalogName.alt_table")
var expectedSchema = new StructType()
.add("ID", StringType, true, defaultMetadata())
- assert(t.schema === expectedSchema)
+ var expectedSchemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(expectedSchema, "jdbcClientType")
+ var schemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(t.schema, "jdbcClientType")
+ assert(schemaWithoutJdbcClientType === expectedSchemaWithoutJdbcClientType)
sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C1 STRING, C2 STRING)")
t = spark.table(s"$catalogName.alt_table")
expectedSchema = expectedSchema
.add("C1", StringType, true, defaultMetadata())
.add("C2", StringType, true, defaultMetadata())
- assert(t.schema === expectedSchema)
+ expectedSchemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(expectedSchema, "jdbcClientType")
+ schemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(t.schema, "jdbcClientType")
+ assert(schemaWithoutJdbcClientType === expectedSchemaWithoutJdbcClientType)
sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C3 STRING)")
t = spark.table(s"$catalogName.alt_table")
expectedSchema = expectedSchema
.add("C3", StringType, true, defaultMetadata())
- assert(t.schema === expectedSchema)
+ expectedSchemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(expectedSchema, "jdbcClientType")
+ schemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(t.schema, "jdbcClientType")
+ assert(schemaWithoutJdbcClientType === expectedSchemaWithoutJdbcClientType)
// Add already existing column
checkError(
exception = intercept[AnalysisException] {
@@ -141,7 +188,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
val e = intercept[AnalysisException] {
sql(s"ALTER TABLE $catalogName.not_existing_table ADD COLUMNS (C4 STRING)")
}
- checkErrorFailedJDBC(e, "FAILED_JDBC.LOAD_TABLE", "not_existing_table")
+ checkErrorTableNotFound(
+ e,
+ s"`$catalogName`.`not_existing_table`",
+ ExpectedContext(
+ s"$catalogName.not_existing_table", 12, 11 + s"$catalogName.not_existing_table".length))
}
test("SPARK-33034: ALTER TABLE ... drop column") {
@@ -152,7 +203,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
val t = spark.table(s"$catalogName.alt_table")
val expectedSchema = new StructType()
.add("C2", StringType, true, defaultMetadata())
- assert(t.schema === expectedSchema)
+ val expectedSchemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(expectedSchema, "jdbcClientType")
+ val schemaWithoutJdbcClientType =
+ removeMetadataFromAllFields(t.schema, "jdbcClientType")
+ assert(schemaWithoutJdbcClientType === expectedSchemaWithoutJdbcClientType)
// Drop a non-existing column
val sqlText = s"ALTER TABLE $catalogName.alt_table DROP COLUMN bad_column"
checkError(
@@ -170,7 +225,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
val e = intercept[AnalysisException] {
sql(s"ALTER TABLE $catalogName.not_existing_table DROP COLUMN C1")
}
- checkErrorFailedJDBC(e, "FAILED_JDBC.LOAD_TABLE", "not_existing_table")
+ checkErrorTableNotFound(
+ e,
+ s"`$catalogName`.`not_existing_table`",
+ ExpectedContext(
+ s"$catalogName.not_existing_table", 12, 11 + s"$catalogName.not_existing_table".length))
}
test("SPARK-33034: ALTER TABLE ... update column type") {
@@ -193,7 +252,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
val e = intercept[AnalysisException] {
sql(s"ALTER TABLE $catalogName.not_existing_table ALTER COLUMN id TYPE DOUBLE")
}
- checkErrorFailedJDBC(e, "FAILED_JDBC.LOAD_TABLE", "not_existing_table")
+ checkErrorTableNotFound(
+ e,
+ s"`$catalogName`.`not_existing_table`",
+ ExpectedContext(
+ s"$catalogName.not_existing_table", 12, 11 + s"$catalogName.not_existing_table".length))
}
test("SPARK-33034: ALTER TABLE ... rename column") {
@@ -221,7 +284,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
val e = intercept[AnalysisException] {
sql(s"ALTER TABLE $catalogName.not_existing_table RENAME COLUMN ID TO C")
}
- checkErrorFailedJDBC(e, "FAILED_JDBC.LOAD_TABLE", "not_existing_table")
+ checkErrorTableNotFound(
+ e,
+ s"`$catalogName`.`not_existing_table`",
+ ExpectedContext(
+ s"$catalogName.not_existing_table", 12, 11 + s"$catalogName.not_existing_table".length))
}
test("SPARK-33034: ALTER TABLE ... update column nullability") {
@@ -232,7 +299,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
val e = intercept[AnalysisException] {
sql(s"ALTER TABLE $catalogName.not_existing_table ALTER COLUMN ID DROP NOT NULL")
}
- checkErrorFailedJDBC(e, "FAILED_JDBC.LOAD_TABLE", "not_existing_table")
+ checkErrorTableNotFound(
+ e,
+ s"`$catalogName`.`not_existing_table`",
+ ExpectedContext(
+ s"$catalogName.not_existing_table", 12, 11 + s"$catalogName.not_existing_table".length))
}
test("CREATE TABLE with table comment") {
diff --git a/connector/kafka-0-10-sql/src/main/resources/error/kafka-error-conditions.json b/connector/kafka-0-10-sql/src/main/resources/error/kafka-error-conditions.json
index 42905c06ca66c..d6a7aa19d0307 100644
--- a/connector/kafka-0-10-sql/src/main/resources/error/kafka-error-conditions.json
+++ b/connector/kafka-0-10-sql/src/main/resources/error/kafka-error-conditions.json
@@ -37,6 +37,11 @@
"Specified: Assigned: "
]
},
+ "KAFKA_NULL_TOPIC_IN_DATA": {
+ "message" : [
+ "The Kafka message data sent to the producer contains a null topic. Use the `topic` option for setting a default topic."
+ ]
+ },
"KAFKA_DATA_LOSS" : {
"message" : [
"Some data may have been lost because they are not available in Kafka any more;",
diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala
index a31fc56bf8920..a6eb13e68c19a 100644
--- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala
+++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala
@@ -177,6 +177,12 @@ object KafkaExceptions {
"specifiedPartitions" -> specifiedPartitions.toString,
"assignedPartitions" -> assignedPartitions.toString))
}
+
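+  /** Raised when a record written to Kafka has a null topic and no default topic option is set. */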
+ def nullTopicInData(): KafkaIllegalStateException = {
+ new KafkaIllegalStateException(
+ errorClass = "KAFKA_NULL_TOPIC_IN_DATA",
+ messageParameters = Map.empty)
+ }
}
/**
diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
index e8f98262a8972..83663386856df 100644
--- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
+++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -27,6 +27,7 @@ import org.apache.kafka.common.header.internals.RecordHeader
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, UnsafeProjection}
+import org.apache.spark.sql.kafka010.KafkaExceptions.nullTopicInData
import org.apache.spark.sql.kafka010.producer.{CachedKafkaProducer, InternalKafkaProducerPool}
import org.apache.spark.sql.types.BinaryType
@@ -95,8 +96,7 @@ private[kafka010] abstract class KafkaRowWriter(
val key = projectedRow.getBinary(1)
val value = projectedRow.getBinary(2)
if (topic == null) {
- throw new NullPointerException(s"null topic present in the data. Use the " +
- s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
+ throw nullTopicInData()
}
val partition: Integer =
if (projectedRow.isNullAt(4)) null else projectedRow.getInt(4)
diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index 22eeae97874b1..e738abf21f597 100644
--- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -1614,7 +1614,7 @@ abstract class KafkaMicroBatchV1SourceSuite extends KafkaMicroBatchSourceSuiteBa
testStream(kafka)(
makeSureGetOffsetCalled,
AssertOnQuery { query =>
- query.logicalPlan.collect {
+ query.logicalPlan.collectFirst {
case StreamingExecutionRelation(_: KafkaSource, _, _) => true
}.nonEmpty
}
diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
index 6087447fa3045..1a884533e818b 100644
--- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
+++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
@@ -637,7 +637,7 @@ class KafkaRelationSuiteV1 extends KafkaRelationSuiteBase {
test("V1 Source is used when set through SQLConf") {
val topic = newTopic()
val df = createDF(topic)
- assert(df.logicalPlan.collect {
+ assert(df.logicalPlan.collectFirst {
case _: LogicalRelation => true
}.nonEmpty)
}
@@ -652,7 +652,7 @@ class KafkaRelationSuiteV2 extends KafkaRelationSuiteBase {
test("V2 Source is used when set through SQLConf") {
val topic = newTopic()
val df = createDF(topic)
- assert(df.logicalPlan.collect {
+ assert(df.logicalPlan.collectFirst {
case _: DataSourceV2Relation => true
}.nonEmpty)
}
diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index 5566785c4d56d..82edba59995ec 100644
--- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -29,7 +29,7 @@ import org.apache.kafka.common.Cluster
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.scalatest.time.SpanSugar._
-import org.apache.spark.{SparkConf, SparkContext, SparkException, TestUtils}
+import org.apache.spark.{SparkConf, SparkContext, TestUtils}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.execution.streaming.{MemoryStream, MemoryStreamBase}
@@ -491,14 +491,17 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase {
test("batch - null topic field value, and no topic option") {
val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value")
- val ex = intercept[SparkException] {
+ val ex = intercept[KafkaIllegalStateException] {
df.write
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.mode("append")
.save()
}
- TestUtils.assertExceptionMsg(ex, "null topic present in the data")
+ checkError(
+ exception = ex,
+ condition = "KAFKA_NULL_TOPIC_IN_DATA"
+ )
}
protected def testUnsupportedSaveModes(msg: (SaveMode) => Seq[String]): Unit = {
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
index 8805d935093e8..0564eee1602a7 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
@@ -55,4 +55,35 @@ private[sql] case class CatalystDataToProtobuf(
override protected def withNewChildInternal(newChild: Expression): CatalystDataToProtobuf =
copy(child = newChild)
+
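+  // binaryFileDescriptorSet is an Option[Array[Byte]]; arrays use reference equality, so
+  // equals/hashCode are overridden to compare the descriptor bytes by content.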
+ override def equals(that: Any): Boolean = {
+ that match {
+ case that: CatalystDataToProtobuf =>
+ this.child == that.child &&
+ this.messageName == that.messageName &&
+ (
+ (this.binaryFileDescriptorSet.isEmpty && that.binaryFileDescriptorSet.isEmpty) ||
+ (
+ this.binaryFileDescriptorSet.nonEmpty && that.binaryFileDescriptorSet.nonEmpty &&
+ this.binaryFileDescriptorSet.get.sameElements(that.binaryFileDescriptorSet.get)
+ )
+ ) &&
+ this.options == that.options
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = {
+ val prime = 31
+ var result = 1
+ var i = 0
+ while (i < binaryFileDescriptorSet.map(_.length).getOrElse(0)) {
+ result = prime * result + binaryFileDescriptorSet.get.apply(i).hashCode
+ i += 1
+ }
+ result = prime * result + child.hashCode
+ result = prime * result + messageName.hashCode
+ result = prime * result + options.hashCode
+ result
+ }
}
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
index a182ac854b28b..b3225d61eb01a 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
@@ -142,4 +142,35 @@ private[sql] case class ProtobufDataToCatalyst(
override protected def withNewChildInternal(newChild: Expression): ProtobufDataToCatalyst =
copy(child = newChild)
+
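+  // As in CatalystDataToProtobuf, compare the optional descriptor byte array by content
+  // rather than by reference.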
+ override def equals(that: Any): Boolean = {
+ that match {
+ case that: ProtobufDataToCatalyst =>
+ this.child == that.child &&
+ this.messageName == that.messageName &&
+ (
+ (this.binaryFileDescriptorSet.isEmpty && that.binaryFileDescriptorSet.isEmpty) ||
+ (
+ this.binaryFileDescriptorSet.nonEmpty && that.binaryFileDescriptorSet.nonEmpty &&
+ this.binaryFileDescriptorSet.get.sameElements(that.binaryFileDescriptorSet.get)
+ )
+ ) &&
+ this.options == that.options
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = {
+ val prime = 31
+ var result = 1
+ var i = 0
+ while (i < binaryFileDescriptorSet.map(_.length).getOrElse(0)) {
+ result = prime * result + binaryFileDescriptorSet.get.apply(i).hashCode
+ i += 1
+ }
+ result = prime * result + child.hashCode
+ result = prime * result + messageName.hashCode
+ result = prime * result + options.hashCode
+ result
+ }
}
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
index abae1d622d3cf..1802bceee1dff 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -244,4 +244,195 @@ class ProtobufCatalystDataConversionSuite
testFileDesc, "org.apache.spark.sql.protobuf.protos.BytesMsg")
assert(withFullName.findFieldByName("bytes_type") != null)
}
+
+ test("CatalystDataToProtobuf equals") {
+ val catalystDataToProtobuf = generateCatalystDataToProtobuf()
+
+ assert(
+ catalystDataToProtobuf
+ == catalystDataToProtobuf.copy()
+ )
+ assert(
+ catalystDataToProtobuf
+ != catalystDataToProtobuf.copy(options = Map("mode" -> "FAILFAST"))
+ )
+ assert(
+ catalystDataToProtobuf
+ != catalystDataToProtobuf.copy(messageName = "otherMessage")
+ )
+ assert(
+ catalystDataToProtobuf
+ != catalystDataToProtobuf.copy(child = Literal.create(0, IntegerType))
+ )
+ assert(
+ catalystDataToProtobuf
+ != catalystDataToProtobuf.copy(binaryFileDescriptorSet = None)
+ )
+
+ val testFileDescCopy = new Array[Byte](testFileDesc.length)
+ testFileDesc.copyToArray(testFileDescCopy)
+ assert(
+ catalystDataToProtobuf
+ == catalystDataToProtobuf.copy(binaryFileDescriptorSet = Some(testFileDescCopy))
+ )
+
+ testFileDescCopy(0) = '0'
+ assert(
+ catalystDataToProtobuf
+ != catalystDataToProtobuf.copy(binaryFileDescriptorSet = Some(testFileDescCopy))
+ )
+ }
+
+ test("CatalystDataToProtobuf hashCode") {
+ val catalystDataToProtobuf = generateCatalystDataToProtobuf()
+
+ assert(
+ catalystDataToProtobuf
+ .copy(options = Map("mode" -> "FAILFAST"))
+ .hashCode != catalystDataToProtobuf.hashCode
+ )
+ assert(
+ catalystDataToProtobuf
+ .copy(messageName = "otherMessage")
+ .hashCode != catalystDataToProtobuf.hashCode
+ )
+ assert(
+ catalystDataToProtobuf
+ .copy(child = Literal.create(0, IntegerType))
+ .hashCode != catalystDataToProtobuf.hashCode
+ )
+ assert(
+ catalystDataToProtobuf
+ .copy(binaryFileDescriptorSet = None)
+ .hashCode != catalystDataToProtobuf.hashCode
+ )
+
+ val testFileDescCopy = new Array[Byte](testFileDesc.length)
+ testFileDesc.copyToArray(testFileDescCopy)
+ assert(
+ catalystDataToProtobuf
+ .copy(
+ binaryFileDescriptorSet = Some(testFileDescCopy)
+ )
+ .hashCode == catalystDataToProtobuf.hashCode
+ )
+
+ testFileDescCopy(0) = '0'
+ assert(
+ catalystDataToProtobuf
+ .copy(
+ binaryFileDescriptorSet = Some(testFileDescCopy)
+ )
+ .hashCode != catalystDataToProtobuf.hashCode
+ )
+ }
+
+ test("ProtobufDataToCatalyst equals") {
+ val catalystDataToProtobuf = generateCatalystDataToProtobuf()
+ val protobufDataToCatalyst = ProtobufDataToCatalyst(
+ catalystDataToProtobuf,
+ "message",
+ Some(testFileDesc),
+ Map("mode" -> "PERMISSIVE")
+ )
+
+ assert(
+ protobufDataToCatalyst
+ == protobufDataToCatalyst.copy()
+ )
+ assert(
+ protobufDataToCatalyst
+ != protobufDataToCatalyst.copy(options = Map("mode" -> "FAILFAST"))
+ )
+ assert(
+ protobufDataToCatalyst
+ != protobufDataToCatalyst.copy(messageName = "otherMessage")
+ )
+ assert(
+ protobufDataToCatalyst
+ != protobufDataToCatalyst.copy(child = Literal.create(0, IntegerType))
+ )
+ assert(
+ protobufDataToCatalyst
+ != protobufDataToCatalyst.copy(binaryFileDescriptorSet = None)
+ )
+
+ val testFileDescCopy = new Array[Byte](testFileDesc.length)
+ testFileDesc.copyToArray(testFileDescCopy)
+ assert(
+ protobufDataToCatalyst
+ == protobufDataToCatalyst.copy(binaryFileDescriptorSet = Some(testFileDescCopy))
+ )
+
+ testFileDescCopy(0) = '0'
+ assert(
+ protobufDataToCatalyst
+ != protobufDataToCatalyst.copy(binaryFileDescriptorSet = Some(testFileDescCopy))
+ )
+ }
+
+ test("ProtobufDataToCatalyst hashCode") {
+ val catalystDataToProtobuf = generateCatalystDataToProtobuf()
+ val protobufDataToCatalyst = ProtobufDataToCatalyst(
+ catalystDataToProtobuf,
+ "message",
+ Some(testFileDesc),
+ Map("mode" -> "PERMISSIVE")
+ )
+
+ assert(
+ protobufDataToCatalyst
+ .copy(options = Map("mode" -> "FAILFAST"))
+ .hashCode != protobufDataToCatalyst.hashCode
+ )
+ assert(
+ protobufDataToCatalyst
+ .copy(messageName = "otherMessage")
+ .hashCode != protobufDataToCatalyst.hashCode
+ )
+ assert(
+ protobufDataToCatalyst
+ .copy(child = Literal.create(0, IntegerType))
+ .hashCode != protobufDataToCatalyst.hashCode
+ )
+ assert(
+ protobufDataToCatalyst
+ .copy(binaryFileDescriptorSet = None)
+ .hashCode != protobufDataToCatalyst.hashCode
+ )
+
+ val testFileDescCopy = new Array[Byte](testFileDesc.length)
+ testFileDesc.copyToArray(testFileDescCopy)
+ assert(
+ protobufDataToCatalyst
+ .copy(
+ binaryFileDescriptorSet = Some(testFileDescCopy)
+ )
+ .hashCode == protobufDataToCatalyst.hashCode
+ )
+
+ testFileDescCopy(0) = '0'
+ assert(
+ protobufDataToCatalyst
+ .copy(
+ binaryFileDescriptorSet = Some(testFileDescCopy)
+ )
+ .hashCode != protobufDataToCatalyst.hashCode
+ )
+ }
+
+ private def generateCatalystDataToProtobuf() = {
+ val schema = StructType(
+ Seq(
+ StructField("a", StringType),
+ StructField("b", IntegerType)
+ )
+ )
+ val messageName = "message"
+ val data = RandomDataGenerator.randomRow(new scala.util.Random(3), schema)
+ val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+ val dataLiteral = Literal.create(converter(data), schema)
+
+ CatalystDataToProtobuf(dataLiteral, messageName, Some(testFileDesc))
+ }
}
diff --git a/core/benchmarks/KryoBenchmark-jdk21-results.txt b/core/benchmarks/KryoBenchmark-jdk21-results.txt
index 704a167c62d0a..ca03441d01a87 100644
--- a/core/benchmarks/KryoBenchmark-jdk21-results.txt
+++ b/core/benchmarks/KryoBenchmark-jdk21-results.txt
@@ -2,27 +2,27 @@
Benchmark Kryo Unsafe vs safe Serialization
================================================================================================
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
AMD EPYC 7763 64-Core Processor
Benchmark Kryo Unsafe vs safe Serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
---------------------------------------------------------------------------------------------------------------------------
-basicTypes: Int with unsafe:true 166 168 3 6.0 165.5 1.0X
-basicTypes: Long with unsafe:true 178 182 5 5.6 178.2 0.9X
-basicTypes: Float with unsafe:true 185 189 6 5.4 184.9 0.9X
-basicTypes: Double with unsafe:true 183 188 9 5.5 183.2 0.9X
-Array: Int with unsafe:true 1 1 0 763.7 1.3 126.4X
-Array: Long with unsafe:true 2 2 0 447.3 2.2 74.1X
-Array: Float with unsafe:true 1 1 0 753.7 1.3 124.8X
-Array: Double with unsafe:true 2 2 0 457.5 2.2 75.7X
-Map of string->Double with unsafe:true 28 28 0 36.2 27.6 6.0X
-basicTypes: Int with unsafe:false 203 204 1 4.9 203.1 0.8X
-basicTypes: Long with unsafe:false 223 224 1 4.5 222.8 0.7X
-basicTypes: Float with unsafe:false 206 207 1 4.9 205.8 0.8X
-basicTypes: Double with unsafe:false 204 205 1 4.9 204.1 0.8X
-Array: Int with unsafe:false 13 13 0 79.5 12.6 13.2X
-Array: Long with unsafe:false 21 22 1 46.6 21.5 7.7X
-Array: Float with unsafe:false 13 13 0 78.6 12.7 13.0X
-Array: Double with unsafe:false 15 15 0 67.8 14.8 11.2X
-Map of string->Double with unsafe:false 28 30 1 35.3 28.3 5.8X
+basicTypes: Int with unsafe:true 166 169 6 6.0 165.7 1.0X
+basicTypes: Long with unsafe:true 178 182 4 5.6 177.5 0.9X
+basicTypes: Float with unsafe:true 183 191 9 5.5 182.7 0.9X
+basicTypes: Double with unsafe:true 186 193 7 5.4 186.4 0.9X
+Array: Int with unsafe:true 1 1 0 745.8 1.3 123.5X
+Array: Long with unsafe:true 2 3 0 451.7 2.2 74.8X
+Array: Float with unsafe:true 1 1 0 743.0 1.3 123.1X
+Array: Double with unsafe:true 2 2 0 475.6 2.1 78.8X
+Map of string->Double with unsafe:true 27 28 1 37.0 27.0 6.1X
+basicTypes: Int with unsafe:false 198 199 1 5.1 197.7 0.8X
+basicTypes: Long with unsafe:false 220 221 1 4.5 219.8 0.8X
+basicTypes: Float with unsafe:false 206 208 1 4.8 206.3 0.8X
+basicTypes: Double with unsafe:false 222 225 2 4.5 221.9 0.7X
+Array: Int with unsafe:false 13 14 1 78.0 12.8 12.9X
+Array: Long with unsafe:false 21 21 1 48.2 20.8 8.0X
+Array: Float with unsafe:false 6 6 0 178.9 5.6 29.6X
+Array: Double with unsafe:false 15 16 0 65.3 15.3 10.8X
+Map of string->Double with unsafe:false 28 29 2 35.2 28.4 5.8X
diff --git a/core/benchmarks/KryoBenchmark-results.txt b/core/benchmarks/KryoBenchmark-results.txt
index 7c8ffceea65c4..bffc9eafb79e3 100644
--- a/core/benchmarks/KryoBenchmark-results.txt
+++ b/core/benchmarks/KryoBenchmark-results.txt
@@ -2,27 +2,27 @@
Benchmark Kryo Unsafe vs safe Serialization
================================================================================================
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
AMD EPYC 7763 64-Core Processor
Benchmark Kryo Unsafe vs safe Serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
---------------------------------------------------------------------------------------------------------------------------
-basicTypes: Int with unsafe:true 177 180 3 5.7 176.9 1.0X
-basicTypes: Long with unsafe:true 188 190 1 5.3 188.1 0.9X
-basicTypes: Float with unsafe:true 190 192 2 5.3 190.2 0.9X
-basicTypes: Double with unsafe:true 199 201 4 5.0 199.0 0.9X
-Array: Int with unsafe:true 1 1 0 783.6 1.3 138.6X
-Array: Long with unsafe:true 2 2 0 491.6 2.0 87.0X
-Array: Float with unsafe:true 1 1 0 757.8 1.3 134.1X
-Array: Double with unsafe:true 2 2 0 497.5 2.0 88.0X
-Map of string->Double with unsafe:true 26 27 3 37.9 26.4 6.7X
-basicTypes: Int with unsafe:false 230 232 1 4.3 230.1 0.8X
-basicTypes: Long with unsafe:false 267 268 1 3.7 267.0 0.7X
-basicTypes: Float with unsafe:false 229 230 1 4.4 229.2 0.8X
-basicTypes: Double with unsafe:false 216 217 1 4.6 216.3 0.8X
-Array: Int with unsafe:false 15 15 0 68.8 14.5 12.2X
-Array: Long with unsafe:false 22 22 0 46.1 21.7 8.2X
-Array: Float with unsafe:false 6 6 0 169.6 5.9 30.0X
-Array: Double with unsafe:false 9 9 0 108.1 9.3 19.1X
-Map of string->Double with unsafe:false 28 28 1 36.3 27.5 6.4X
+basicTypes: Int with unsafe:true 172 174 1 5.8 171.9 1.0X
+basicTypes: Long with unsafe:true 196 197 0 5.1 196.2 0.9X
+basicTypes: Float with unsafe:true 193 195 2 5.2 192.7 0.9X
+basicTypes: Double with unsafe:true 193 194 1 5.2 193.2 0.9X
+Array: Int with unsafe:true 1 1 0 715.2 1.4 122.9X
+Array: Long with unsafe:true 2 2 0 474.2 2.1 81.5X
+Array: Float with unsafe:true 1 1 0 718.2 1.4 123.5X
+Array: Double with unsafe:true 2 2 0 475.8 2.1 81.8X
+Map of string->Double with unsafe:true 27 28 0 36.7 27.2 6.3X
+basicTypes: Int with unsafe:false 207 209 5 4.8 207.3 0.8X
+basicTypes: Long with unsafe:false 239 241 2 4.2 238.9 0.7X
+basicTypes: Float with unsafe:false 215 217 2 4.6 215.4 0.8X
+basicTypes: Double with unsafe:false 220 225 7 4.5 220.2 0.8X
+Array: Int with unsafe:false 16 20 7 63.4 15.8 10.9X
+Array: Long with unsafe:false 22 22 0 45.9 21.8 7.9X
+Array: Float with unsafe:false 6 6 1 170.0 5.9 29.2X
+Array: Double with unsafe:false 10 10 0 98.6 10.1 16.9X
+Map of string->Double with unsafe:false 28 29 1 35.9 27.9 6.2X
diff --git a/core/benchmarks/KryoIteratorBenchmark-jdk21-results.txt b/core/benchmarks/KryoIteratorBenchmark-jdk21-results.txt
index 21b4df268d265..835b16e24d95f 100644
--- a/core/benchmarks/KryoIteratorBenchmark-jdk21-results.txt
+++ b/core/benchmarks/KryoIteratorBenchmark-jdk21-results.txt
@@ -2,27 +2,27 @@
Benchmark of kryo asIterator on deserialization stream
================================================================================================
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
AMD EPYC 7763 64-Core Processor
Benchmark of kryo asIterator on deserialization stream: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
---------------------------------------------------------------------------------------------------------------------------------------------
-Colletion of int with 1 elements, useIterator: true 6 6 0 1.7 584.3 1.0X
-Colletion of int with 10 elements, useIterator: true 13 14 0 0.8 1330.3 0.4X
-Colletion of int with 100 elements, useIterator: true 83 84 0 0.1 8310.8 0.1X
-Colletion of string with 1 elements, useIterator: true 8 8 0 1.3 763.6 0.8X
-Colletion of string with 10 elements, useIterator: true 22 23 0 0.5 2209.8 0.3X
-Colletion of string with 100 elements, useIterator: true 163 164 2 0.1 16262.3 0.0X
-Colletion of Array[int] with 1 elements, useIterator: true 7 8 0 1.4 730.9 0.8X
-Colletion of Array[int] with 10 elements, useIterator: true 20 20 0 0.5 1990.1 0.3X
-Colletion of Array[int] with 100 elements, useIterator: true 155 156 1 0.1 15527.8 0.0X
-Colletion of int with 1 elements, useIterator: false 6 6 0 1.7 599.4 1.0X
-Colletion of int with 10 elements, useIterator: false 13 14 0 0.7 1337.3 0.4X
-Colletion of int with 100 elements, useIterator: false 83 84 1 0.1 8320.9 0.1X
-Colletion of string with 1 elements, useIterator: false 7 8 0 1.4 731.4 0.8X
-Colletion of string with 10 elements, useIterator: false 22 22 0 0.5 2160.9 0.3X
-Colletion of string with 100 elements, useIterator: false 170 171 0 0.1 17015.3 0.0X
-Colletion of Array[int] with 1 elements, useIterator: false 7 8 1 1.4 710.4 0.8X
-Colletion of Array[int] with 10 elements, useIterator: false 19 20 0 0.5 1925.0 0.3X
-Colletion of Array[int] with 100 elements, useIterator: false 143 144 2 0.1 14267.3 0.0X
+Colletion of int with 1 elements, useIterator: true 6 7 0 1.5 645.9 1.0X
+Colletion of int with 10 elements, useIterator: true 13 14 0 0.8 1330.8 0.5X
+Colletion of int with 100 elements, useIterator: true 80 81 1 0.1 7987.1 0.1X
+Colletion of string with 1 elements, useIterator: true 8 8 0 1.3 787.4 0.8X
+Colletion of string with 10 elements, useIterator: true 21 21 0 0.5 2113.7 0.3X
+Colletion of string with 100 elements, useIterator: true 161 162 1 0.1 16108.9 0.0X
+Colletion of Array[int] with 1 elements, useIterator: true 7 8 0 1.3 747.8 0.9X
+Colletion of Array[int] with 10 elements, useIterator: true 19 19 0 0.5 1879.8 0.3X
+Colletion of Array[int] with 100 elements, useIterator: true 140 141 1 0.1 14008.3 0.0X
+Colletion of int with 1 elements, useIterator: false 6 7 0 1.6 642.6 1.0X
+Colletion of int with 10 elements, useIterator: false 14 15 1 0.7 1414.6 0.5X
+Colletion of int with 100 elements, useIterator: false 87 88 1 0.1 8699.0 0.1X
+Colletion of string with 1 elements, useIterator: false 7 8 0 1.3 746.5 0.9X
+Colletion of string with 10 elements, useIterator: false 22 22 0 0.5 2192.3 0.3X
+Colletion of string with 100 elements, useIterator: false 161 162 2 0.1 16091.2 0.0X
+Colletion of Array[int] with 1 elements, useIterator: false 7 8 0 1.4 719.6 0.9X
+Colletion of Array[int] with 10 elements, useIterator: false 19 19 0 0.5 1869.6 0.3X
+Colletion of Array[int] with 100 elements, useIterator: false 138 139 1 0.1 13766.1 0.0X
diff --git a/core/benchmarks/KryoIteratorBenchmark-results.txt b/core/benchmarks/KryoIteratorBenchmark-results.txt
index 546409e81695f..6caa842d0e4d2 100644
--- a/core/benchmarks/KryoIteratorBenchmark-results.txt
+++ b/core/benchmarks/KryoIteratorBenchmark-results.txt
@@ -2,27 +2,27 @@
Benchmark of kryo asIterator on deserialization stream
================================================================================================
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
AMD EPYC 7763 64-Core Processor
Benchmark of kryo asIterator on deserialization stream: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
---------------------------------------------------------------------------------------------------------------------------------------------
-Colletion of int with 1 elements, useIterator: true 6 7 0 1.6 641.0 1.0X
-Colletion of int with 10 elements, useIterator: true 13 14 0 0.7 1334.5 0.5X
-Colletion of int with 100 elements, useIterator: true 81 82 0 0.1 8144.5 0.1X
-Colletion of string with 1 elements, useIterator: true 8 8 0 1.3 760.7 0.8X
-Colletion of string with 10 elements, useIterator: true 22 23 0 0.4 2227.9 0.3X
-Colletion of string with 100 elements, useIterator: true 165 165 1 0.1 16461.8 0.0X
-Colletion of Array[int] with 1 elements, useIterator: true 7 8 0 1.3 747.8 0.9X
-Colletion of Array[int] with 10 elements, useIterator: true 20 20 0 0.5 1980.0 0.3X
-Colletion of Array[int] with 100 elements, useIterator: true 149 151 1 0.1 14910.6 0.0X
-Colletion of int with 1 elements, useIterator: false 6 7 0 1.6 618.9 1.0X
-Colletion of int with 10 elements, useIterator: false 13 14 0 0.8 1329.3 0.5X
-Colletion of int with 100 elements, useIterator: false 82 83 1 0.1 8176.8 0.1X
-Colletion of string with 1 elements, useIterator: false 7 8 0 1.3 743.4 0.9X
-Colletion of string with 10 elements, useIterator: false 21 22 0 0.5 2137.1 0.3X
-Colletion of string with 100 elements, useIterator: false 161 162 1 0.1 16131.4 0.0X
-Colletion of Array[int] with 1 elements, useIterator: false 7 8 0 1.4 713.0 0.9X
-Colletion of Array[int] with 10 elements, useIterator: false 19 20 0 0.5 1910.7 0.3X
-Colletion of Array[int] with 100 elements, useIterator: false 140 142 1 0.1 14021.4 0.0X
+Colletion of int with 1 elements, useIterator: true 6 6 0 1.6 629.3 1.0X
+Colletion of int with 10 elements, useIterator: true 14 14 0 0.7 1350.0 0.5X
+Colletion of int with 100 elements, useIterator: true 83 83 1 0.1 8255.8 0.1X
+Colletion of string with 1 elements, useIterator: true 8 8 0 1.3 750.8 0.8X
+Colletion of string with 10 elements, useIterator: true 21 21 1 0.5 2116.4 0.3X
+Colletion of string with 100 elements, useIterator: true 162 163 1 0.1 16191.1 0.0X
+Colletion of Array[int] with 1 elements, useIterator: true 7 8 0 1.4 732.3 0.9X
+Colletion of Array[int] with 10 elements, useIterator: true 19 19 0 0.5 1906.4 0.3X
+Colletion of Array[int] with 100 elements, useIterator: true 142 143 0 0.1 14222.9 0.0X
+Colletion of int with 1 elements, useIterator: false 6 6 0 1.7 604.6 1.0X
+Colletion of int with 10 elements, useIterator: false 13 13 0 0.8 1325.8 0.5X
+Colletion of int with 100 elements, useIterator: false 83 83 0 0.1 8261.2 0.1X
+Colletion of string with 1 elements, useIterator: false 7 8 1 1.4 719.2 0.9X
+Colletion of string with 10 elements, useIterator: false 22 22 0 0.5 2203.0 0.3X
+Colletion of string with 100 elements, useIterator: false 163 163 1 0.1 16294.8 0.0X
+Colletion of Array[int] with 1 elements, useIterator: false 7 7 0 1.5 674.8 0.9X
+Colletion of Array[int] with 10 elements, useIterator: false 18 19 0 0.6 1808.7 0.3X
+Colletion of Array[int] with 100 elements, useIterator: false 135 135 0 0.1 13481.7 0.0X
diff --git a/core/benchmarks/KryoSerializerBenchmark-jdk21-results.txt b/core/benchmarks/KryoSerializerBenchmark-jdk21-results.txt
index 7f61d50f98059..e2c03b0acd23c 100644
--- a/core/benchmarks/KryoSerializerBenchmark-jdk21-results.txt
+++ b/core/benchmarks/KryoSerializerBenchmark-jdk21-results.txt
@@ -2,11 +2,11 @@
Benchmark KryoPool vs old"pool of 1" implementation
================================================================================================
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
AMD EPYC 7763 64-Core Processor
Benchmark KryoPool vs old"pool of 1" implementation: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-----------------------------------------------------------------------------------------------------------------------------------
-KryoPool:true 3586 4906 1321 0.0 7171228.8 1.0X
-KryoPool:false 5722 7623 1380 0.0 11444024.2 0.6X
+KryoPool:true 3610 5082 1510 0.0 7219633.5 1.0X
+KryoPool:false 5886 7699 1454 0.0 11772046.4 0.6X
diff --git a/core/benchmarks/KryoSerializerBenchmark-results.txt b/core/benchmarks/KryoSerializerBenchmark-results.txt
index deef4b3b983e5..54b21654ce054 100644
--- a/core/benchmarks/KryoSerializerBenchmark-results.txt
+++ b/core/benchmarks/KryoSerializerBenchmark-results.txt
@@ -2,11 +2,11 @@
Benchmark KryoPool vs old"pool of 1" implementation
================================================================================================
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
AMD EPYC 7763 64-Core Processor
Benchmark KryoPool vs old"pool of 1" implementation: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-----------------------------------------------------------------------------------------------------------------------------------
-KryoPool:true 3517 5031 1861 0.0 7033673.5 1.0X
-KryoPool:false 5802 7614 1293 0.0 11604325.7 0.6X
+KryoPool:true 3682 5317 1787 0.0 7363182.2 1.0X
+KryoPool:false 5853 7922 1339 0.0 11705848.9 0.6X
diff --git a/core/pom.xml b/core/pom.xml
index df009bc28bca7..79aa783cf2091 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -59,6 +59,14 @@
      <groupId>com.twitter</groupId>
      <artifactId>chill-java</artifactId>
    </dependency>
+    <dependency>
+      <groupId>com.esotericsoftware</groupId>
+      <artifactId>kryo-shaded</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.objenesis</groupId>
+      <artifactId>objenesis</artifactId>
+    </dependency>
    <dependency>
      <groupId>com.github.jnr</groupId>
      <artifactId>jnr-posix</artifactId>
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index bd9f58bf7415f..e98554db22524 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -278,15 +278,15 @@ private long trySpillAndAcquire(
}
} catch (ClosedByInterruptException | InterruptedIOException e) {
// This is called by the user to kill a task (e.g. a speculative task).
- logger.error("error while calling spill() on {}", e,
+ logger.error("Error while calling spill() on {}", e,
MDC.of(LogKeys.MEMORY_CONSUMER$.MODULE$, consumerToSpill));
throw new RuntimeException(e.getMessage());
} catch (IOException e) {
- logger.error("error while calling spill() on {}", e,
+ logger.error("Error while calling spill() on {}", e,
MDC.of(LogKeys.MEMORY_CONSUMER$.MODULE$, consumerToSpill));
// checkstyle.off: RegexpSinglelineJava
throw new SparkOutOfMemoryError(
- "_LEGACY_ERROR_TEMP_3300",
+ "SPILL_OUT_OF_MEMORY",
new HashMap() {{
put("consumerToSpill", consumerToSpill.toString());
put("message", e.getMessage());
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index ff00f8c1c7517..65aa7c815fc42 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -216,7 +216,7 @@ public void expandPointerArray(LongArray newArray) {
if (array != null) {
if (newArray.size() < array.size()) {
// checkstyle.off: RegexpSinglelineJava
- throw new SparkOutOfMemoryError("_LEGACY_ERROR_TEMP_3301", new HashMap<>());
+ throw new SparkOutOfMemoryError("POINTER_ARRAY_OUT_OF_MEMORY", new HashMap<>());
// checkstyle.on: RegexpSinglelineJava
}
Platform.copyMemory(
diff --git a/core/src/main/resources/org/apache/spark/ui/static/environmentpage.js b/core/src/main/resources/org/apache/spark/ui/static/environmentpage.js
new file mode 100644
index 0000000000000..0bced49a52fe6
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/environmentpage.js
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* global $ */
+
+$(document).ready(function(){
+ $('th').on('click', function(e) {
+ let inputBox = $(this).find('.env-table-filter-input');
+ if (inputBox.length === 0) {
+ $('<input type="text" class="env-table-filter-input">')
+ .appendTo(this)
+ .focus();
+ } else {
+ inputBox.toggleClass('d-none');
+ inputBox.focus();
+ }
+ e.stopPropagation();
+ });
+
+ $(document).on('click', function() {
+ $('.env-table-filter-input').toggleClass('d-none', true);
+ });
+
+ $(document).on('input', '.env-table-filter-input', function() {
+ const table = $(this).closest('table');
+ const filters = table.find('.env-table-filter-input').map(function() {
+ const columnIdx = $(this).closest('th').index();
+ const searchString = $(this).val().toLowerCase();
+ return { columnIdx, searchString };
+ }).get();
+
+ table.find('tbody tr').each(function() {
+ let showRow = true;
+ for (const filter of filters) {
+ const cellText = $(this).find('td').eq(filter.columnIdx).text().toLowerCase();
+ if (filter.searchString && cellText.indexOf(filter.searchString) === -1) {
+ showRow = false;
+ break;
+ }
+ }
+ if (showRow) {
+ $(this).show();
+ } else {
+ $(this).hide();
+ }
+ });
+ });
+});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/scroll-button.js b/core/src/main/resources/org/apache/spark/ui/static/scroll-button.js
new file mode 100644
index 0000000000000..3b6bd7c57eb9b
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/scroll-button.js
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+export { addScrollButton };
+
+function createBtn(top = true) {
+ const button = document.createElement('div');
+ button.classList.add('scroll-btn-half');
+ const className = top ? 'scroll-btn-top' : 'scroll-btn-bottom';
+ button.classList.add(className);
+ button.addEventListener('click', function () {
+ window.scrollTo({
+ top: top ? 0 : document.body.scrollHeight,
+ behavior: 'smooth' });
+ })
+ return button;
+}
+
+function addScrollButton() {
+ const containerClass = 'scroll-btn-container';
+ if (document.querySelector(`.${containerClass}`)) {
+ return;
+ }
+ const container = document.createElement('div');
+ container.className = containerClass;
+ container.appendChild(createBtn());
+ container.appendChild(createBtn(false));
+ document.body.appendChild(container);
+}
+
+document.addEventListener('DOMContentLoaded', function () {
+ addScrollButton();
+});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js
index 839746762f4d2..b3aa85f64c5d3 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/table.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/table.js
@@ -82,31 +82,26 @@ function onMouseOverAndOut(threadId) {
}
function onSearchStringChange() {
- var searchString = $('#search').val().toLowerCase();
+ const searchString = $('#search').val().toLowerCase();
//remove the stacktrace
collapseAllThreadStackTrace(false);
- if (searchString.length == 0) {
- $('tr').each(function() {
+ $('tr[id^="thread_"]').each(function() {
+ if (searchString.length === 0) {
$(this).removeClass('d-none')
- })
- } else {
- $('tr').each(function(){
- if($(this).attr('id') && $(this).attr('id').match(/thread_[0-9]+_tr/) ) {
- var children = $(this).children();
- var found = false;
- for (var i = 0; i < children.length; i++) {
- if (children.eq(i).text().toLowerCase().indexOf(searchString) >= 0) {
- found = true;
- }
- }
- if (found) {
- $(this).removeClass('d-none')
+ } else {
+ let found = false;
+ const children = $(this).children();
+ let i = 0;
+ while(!found && i < children.length) {
+ if (children.eq(i).text().toLowerCase().indexOf(searchString) >= 0) {
+ found = true;
} else {
- $(this).addClass('d-none')
+ i++;
}
}
- });
- }
+ $(this).toggleClass('d-none', !found);
+ }
+ });
}
/* eslint-enable no-unused-vars */
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index bf9b230446b26..52089153b8296 100755
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -446,3 +446,55 @@ a.downloadbutton {
#active-executors-table th:first-child {
border-left: 1px solid #dddddd;
}
+
+.scroll-btn-container {
+ position: fixed;
+ bottom: 20px;
+ right: 20px;
+ width: 48px;
+ height: 80px;
+ border-radius: 8px;
+ background-color: rgba(0, 136, 204, 0.5);
+ display: flex;
+ flex-direction: column;
+ overflow: hidden;
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
+}
+
+.scroll-btn-half {
+ flex: 1;
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ cursor: pointer;
+ transition: background-color 0.3s ease;
+ position: relative;
+}
+
+.scroll-btn-half:hover {
+ background-color: rgba(0, 136, 204, 1);
+}
+
+.scroll-btn-top::before {
+ content: '';
+ width: 0;
+ height: 0;
+ border-left: 10px solid transparent;
+ border-right: 10px solid transparent;
+ border-bottom: 15px solid white;
+ position: absolute;
+ top: 50%;
+ transform: translateY(-50%);
+}
+
+.scroll-btn-bottom::before {
+ content: '';
+ width: 0;
+ height: 0;
+ border-left: 10px solid transparent;
+ border-right: 10px solid transparent;
+ border-top: 15px solid white;
+ position: absolute;
+ bottom: 50%;
+ transform: translateY(50%);
+}
diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
index adce6c3f5ffdb..3f95515c04d29 100644
--- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -17,11 +17,12 @@
package org.apache.spark
-import java.util.TimerTask
-import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.{ArrayList, Collections, TimerTask}
+import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit}
import java.util.function.Consumer
import scala.collection.mutable.{ArrayBuffer, HashSet}
+import scala.jdk.CollectionConverters._
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
@@ -53,8 +54,9 @@ private[spark] class BarrierCoordinator(
// TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to
// fetch result, we shall fix the issue.
- private lazy val timer = ThreadUtils.newSingleThreadScheduledExecutor(
+ private lazy val timer = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
"BarrierCoordinator barrier epoch increment timer")
+ private val timerFutures = Collections.synchronizedList(new ArrayList[ScheduledFuture[_]])
// Listen to StageCompleted event, clear corresponding ContextBarrierState.
private val listener = new SparkListener {
@@ -80,8 +82,10 @@ private[spark] class BarrierCoordinator(
states.forEachValue(1, clearStateConsumer)
states.clear()
listenerBus.removeListener(listener)
- ThreadUtils.shutdown(timer)
} finally {
+ timerFutures.asScala.foreach(_.cancel(false))
+ timerFutures.clear()
+ ThreadUtils.shutdown(timer)
super.onStop()
}
}
@@ -134,11 +138,8 @@ private[spark] class BarrierCoordinator(
// Cancel the current active TimerTask and release resources.
private def cancelTimerTask(): Unit = {
- if (timerTask != null) {
- timerTask.cancel()
- timer.purge()
- timerTask = null
- }
+ timerFutures.asScala.foreach(_.cancel(false))
+ timerFutures.clear()
}
// Process the global sync request. The barrier() call succeed if collected enough requests
@@ -173,7 +174,8 @@ private[spark] class BarrierCoordinator(
// we may timeout for the sync.
if (requesters.isEmpty) {
initTimerTask(this)
- timer.schedule(timerTask, timeoutInSecs, TimeUnit.SECONDS)
+ val timerFuture = timer.schedule(timerTask, timeoutInSecs, TimeUnit.SECONDS)
+ timerFutures.add(timerFuture)
}
// Add the requester to array of RPCCallContexts pending for reply.
requesters += requester
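
The BarrierCoordinator change above replaces TimerTask bookkeeping (cancel + purge) with ScheduledFuture handles tracked in a synchronized list, so pending timeouts can be cancelled when the coordinator stops. A minimal standalone sketch of that pattern against a plain JDK scheduled executor follows; it is not Spark's code and the names are illustrative.

```scala
import java.util.{ArrayList, Collections}
import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}

import scala.jdk.CollectionConverters._

object TimerFuturesSketch {
  private val timer = Executors.newSingleThreadScheduledExecutor()
  // Keep a handle to every scheduled timeout so it can be cancelled later.
  private val timerFutures =
    Collections.synchronizedList(new ArrayList[ScheduledFuture[_]]())

  def scheduleTimeout(timeoutInSecs: Long)(body: => Unit): Unit = {
    val task: Runnable = () => body
    timerFutures.add(timer.schedule(task, timeoutInSecs, TimeUnit.SECONDS))
  }

  // Cancel timeouts that have not started yet without interrupting running ones.
  def cancelPending(): Unit = {
    timerFutures.asScala.foreach(_.cancel(false))
    timerFutures.clear()
  }

  def stop(): Unit = {
    cancelPending()
    timer.shutdown()
  }
}
```

Passing `false` to `cancel` lets an in-flight timeout finish while dropping the ones that have not started, which matches the behavior used in the hunk above.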
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index c38d552a27aa5..47f287293974f 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -82,7 +82,7 @@ class BarrierTaskContext private[spark] (
}
}
// Log the update of global sync every 1 minute.
- timer.scheduleAtFixedRate(timerTask, 1, 1, TimeUnit.MINUTES)
+ val timerFuture = timer.scheduleAtFixedRate(timerTask, 1, 1, TimeUnit.MINUTES)
try {
val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]](
@@ -121,7 +121,7 @@ class BarrierTaskContext private[spark] (
logProgressInfo(log"failed to perform global sync", Some(startTime))
throw e
} finally {
- timerTask.cancel()
+ timerFuture.cancel(false)
timer.purge()
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d643983ef5dfe..8d1871cf04d67 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.api.python
import java.io._
import java.net._
+import java.nio.channels.{Channels, SocketChannel}
import java.nio.charset.StandardCharsets
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
@@ -38,8 +39,9 @@ import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.{Logging, MDC}
-import org.apache.spark.internal.LogKeys.{HOST, PORT}
+import org.apache.spark.internal.LogKeys.{HOST, PORT, SOCKET_ADDRESS}
import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
@@ -231,9 +233,9 @@ private[spark] object PythonRDD extends Logging {
* server object that can be used to join the JVM serving thread in Python.
*/
def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
- val handleFunc = (sock: Socket) => {
- val out = new DataOutputStream(sock.getOutputStream)
- val in = new DataInputStream(sock.getInputStream)
+ val handleFunc = (sock: SocketChannel) => {
+ val out = new DataOutputStream(Channels.newOutputStream(sock))
+ val in = new DataInputStream(Channels.newInputStream(sock))
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
@@ -287,7 +289,7 @@ private[spark] object PythonRDD extends Logging {
}
val server = new SocketFuncServer(authHelper, "serve toLocalIterator", handleFunc)
- Array(server.port, server.secret, server)
+ Array(server.connInfo, server.secret, server)
}
def readRDDFromFile(
@@ -716,35 +718,51 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
* collects a list of pickled strings that we pass to Python through a socket.
*/
private[spark] class PythonAccumulatorV2(
- @transient private val serverHost: String,
- private val serverPort: Int,
- private val secretToken: String)
+ @transient private val serverHost: Option[String],
+ private val serverPort: Option[Int],
+ private val secretToken: Option[String],
+ @transient private val socketPath: Option[String])
extends CollectionAccumulator[Array[Byte]] with Logging {
- Utils.checkHost(serverHost)
+ // Unix domain socket
+ def this(socketPath: String) = this(None, None, None, Some(socketPath))
+ // TCP socket
+ def this(serverHost: String, serverPort: Int, secretToken: String) = this(
+ Some(serverHost), Some(serverPort), Some(secretToken), None)
+
+ serverHost.foreach(Utils.checkHost)
val bufferSize = SparkEnv.get.conf.get(BUFFER_SIZE)
+ val isUnixDomainSock = SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
/**
* We try to reuse a single Socket to transfer accumulator updates, as they are all added
* by the DAGScheduler's single-threaded RpcEndpoint anyway.
*/
- @transient private var socket: Socket = _
-
- private def openSocket(): Socket = synchronized {
- if (socket == null || socket.isClosed) {
- socket = new Socket(serverHost, serverPort)
- logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, serverHost)}" +
- log" port: ${MDC(PORT, serverPort)}")
+ @transient private var socket: SocketChannel = _
+
+ private def openSocket(): SocketChannel = synchronized {
+ if (socket == null || !socket.isOpen) {
+ if (isUnixDomainSock) {
+ socket = SocketChannel.open(UnixDomainSocketAddress.of(socketPath.get))
+ logInfo(
+ log"Connected to AccumulatorServer at socket: ${MDC(SOCKET_ADDRESS, socketPath.get)}")
+ } else {
+ socket = SocketChannel.open(new InetSocketAddress(serverHost.get, serverPort.get))
+ logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, serverHost.get)}" +
+ log" port: ${MDC(PORT, serverPort.get)}")
+ }
// send the secret just for the initial authentication when opening a new connection
- socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+ secretToken.foreach { token =>
+ Channels.newOutputStream(socket).write(token.getBytes(StandardCharsets.UTF_8))
+ }
}
socket
}
// Need to override so the types match with PythonFunction
override def copyAndReset(): PythonAccumulatorV2 = {
- new PythonAccumulatorV2(serverHost, serverPort, secretToken)
+ new PythonAccumulatorV2(serverHost, serverPort, secretToken, socketPath)
}
override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
@@ -757,8 +775,9 @@ private[spark] class PythonAccumulatorV2(
} else {
// This happens on the master, where we pass the updates to Python through a socket
val socket = openSocket()
- val in = socket.getInputStream
- val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
+ val in = Channels.newInputStream(socket)
+ val out = new DataOutputStream(
+ new BufferedOutputStream(Channels.newOutputStream(socket), bufferSize))
val values = other.value
out.writeInt(values.size)
for (array <- values.asScala) {
@@ -831,21 +850,21 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
def setupEncryptionServer(): Array[Any] = {
encryptionServer = new SocketAuthServer[Unit]("broadcast-encrypt-server") {
- override def handleConnection(sock: Socket): Unit = {
+ override def handleConnection(sock: SocketChannel): Unit = {
val env = SparkEnv.get
- val in = sock.getInputStream()
+ val in = Channels.newInputStream(sock)
val abspath = new File(path).getAbsolutePath
val out = env.serializerManager.wrapForEncryption(new FileOutputStream(abspath))
DechunkedInputStream.dechunkAndCopyToOutput(in, out)
}
}
- Array(encryptionServer.port, encryptionServer.secret)
+ Array(encryptionServer.connInfo, encryptionServer.secret)
}
def setupDecryptionServer(): Array[Any] = {
decryptionServer = new SocketAuthServer[Unit]("broadcast-decrypt-server-for-driver") {
- override def handleConnection(sock: Socket): Unit = {
- val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream()))
+ override def handleConnection(sock: SocketChannel): Unit = {
+ val out = new DataOutputStream(new BufferedOutputStream(Channels.newOutputStream(sock)))
Utils.tryWithSafeFinally {
val in = SparkEnv.get.serializerManager.wrapForEncryption(new FileInputStream(path))
Utils.tryWithSafeFinally {
@@ -859,7 +878,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
}
}
}
- Array(decryptionServer.port, decryptionServer.secret)
+ Array(decryptionServer.connInfo, decryptionServer.secret)
}
def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult()
@@ -945,8 +964,8 @@ private[spark] class EncryptedPythonBroadcastServer(
val idsAndFiles: Seq[(Long, String)])
extends SocketAuthServer[Unit]("broadcast-decrypt-server") with Logging {
- override def handleConnection(socket: Socket): Unit = {
- val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
+ override def handleConnection(socket: SocketChannel): Unit = {
+ val out = new DataOutputStream(new BufferedOutputStream(Channels.newOutputStream(socket)))
var socketIn: InputStream = null
// send the broadcast id, then the decrypted data. We don't need to send the length, the
// the python pickle module just needs a stream.
@@ -962,7 +981,7 @@ private[spark] class EncryptedPythonBroadcastServer(
}
logTrace("waiting for python to accept broadcast data over socket")
out.flush()
- socketIn = socket.getInputStream()
+ socketIn = Channels.newInputStream(socket)
socketIn.read()
logTrace("done serving broadcast data")
} {
@@ -983,8 +1002,8 @@ private[spark] class EncryptedPythonBroadcastServer(
private[spark] abstract class PythonRDDServer
extends SocketAuthServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
- def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
- val in = sock.getInputStream()
+ def handleConnection(sock: SocketChannel): JavaRDD[Array[Byte]] = {
+ val in = Channels.newInputStream(sock)
val dechunkedInput: InputStream = new DechunkedInputStream(in)
streamToRDD(dechunkedInput)
}
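
The PythonRDD and PythonAccumulatorV2 hunks above move from `java.net.Socket` to `SocketChannel` so one code path can serve both TCP and Unix domain sockets. A minimal sketch of that dual connection logic, assuming JDK 16+ for `UnixDomainSocketAddress`; the helper object and method names are hypothetical, not part of this patch.

```scala
import java.io.{DataInputStream, DataOutputStream}
import java.net.{InetSocketAddress, UnixDomainSocketAddress}
import java.nio.channels.{Channels, SocketChannel}

object ChannelSketch {
  // Dial either a Unix domain socket path (JDK 16+) or a plain TCP host/port.
  def openChannel(socketPath: Option[String], hostPort: Option[(String, Int)]): SocketChannel =
    socketPath match {
      case Some(path) => SocketChannel.open(UnixDomainSocketAddress.of(path))
      case None =>
        val (host, port) = hostPort.getOrElse(
          throw new IllegalArgumentException("either a socket path or a host/port is required"))
        SocketChannel.open(new InetSocketAddress(host, port))
    }

  // Adapt the channel to the stream APIs used throughout the patch.
  def streams(channel: SocketChannel): (DataInputStream, DataOutputStream) =
    (new DataInputStream(Channels.newInputStream(channel)),
      new DataOutputStream(Channels.newOutputStream(channel)))
}
```

`Channels.newInputStream`/`newOutputStream` are what let the existing `DataInputStream`/`DataOutputStream` code keep working unchanged regardless of the transport.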
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 9b107cf7a3bdc..043734aff71f5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -20,9 +20,9 @@ package org.apache.spark.api.python
import java.io._
import java.net._
import java.nio.ByteBuffer
-import java.nio.channels.SelectionKey
-import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.channels.{AsynchronousCloseException, Channels, SelectionKey, ServerSocketChannel, SocketChannel}
import java.nio.file.{Files => JavaFiles, Path}
+import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean
@@ -64,6 +64,8 @@ private[spark] object PythonEvalType {
val SQL_COGROUPED_MAP_ARROW_UDF = 210
val SQL_TRANSFORM_WITH_STATE_PANDAS_UDF = 211
val SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF = 212
+ val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF = 213
+ val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF = 214
val SQL_TABLE_UDF = 300
val SQL_ARROW_TABLE_UDF = 301
@@ -88,6 +90,9 @@ private[spark] object PythonEvalType {
case SQL_TRANSFORM_WITH_STATE_PANDAS_UDF => "SQL_TRANSFORM_WITH_STATE_PANDAS_UDF"
case SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF =>
"SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF"
+ case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF => "SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF"
+ case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
+ "SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF"
}
}
@@ -201,9 +206,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
// Python accumulator is always set in production except in tests. See SPARK-27893
private val maybeAccumulator: Option[PythonAccumulator] = Option(accumulator)
- // Expose a ServerSocket to support method calls via socket from Python side. Only relevant for
- // for tasks that are a part of barrier stage, refer [[BarrierTaskContext]] for details.
- private[spark] var serverSocket: Option[ServerSocket] = None
+ // Expose a ServerSocketChannel to support method calls via socket from Python side.
+ // Only relevant for tasks that are a part of barrier stage, refer
+ // `BarrierTaskContext` for details.
+ private[spark] var serverSocketChannel: Option[ServerSocketChannel] = None
// Authentication helper used when serving method calls via socket from Python side.
private lazy val authHelper = new SocketAuthHelper(conf)
@@ -274,7 +280,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
worker.stop()
} catch {
case e: Exception =>
- logWarning("Failed to stop worker")
+ logWarning(log"Failed to stop worker", e)
}
}
}
@@ -347,6 +353,11 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
def writeNextInputToStream(dataOut: DataOutputStream): Boolean
def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions {
+ val isUnixDomainSock = authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
+ lazy val sockPath = new File(
+ authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+ .getOrElse(System.getProperty("java.io.tmpdir")),
+ s".${UUID.randomUUID()}.sock")
try {
// Partition index
dataOut.writeInt(partitionIndex)
@@ -356,27 +367,34 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
// Init a ServerSocket to accept method calls from Python side.
val isBarrier = context.isInstanceOf[BarrierTaskContext]
if (isBarrier) {
- serverSocket = Some(new ServerSocket(/* port */ 0,
- /* backlog */ 1,
- InetAddress.getByName("localhost")))
- // A call to accept() for ServerSocket shall block infinitely.
- serverSocket.foreach(_.setSoTimeout(0))
+ if (isUnixDomainSock) {
+ serverSocketChannel = Some(ServerSocketChannel.open(StandardProtocolFamily.UNIX))
+ sockPath.deleteOnExit()
+ serverSocketChannel.get.bind(UnixDomainSocketAddress.of(sockPath.getPath))
+ } else {
+ serverSocketChannel = Some(ServerSocketChannel.open())
+ serverSocketChannel.foreach(_.bind(
+ new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1))
+ // A call to accept() for ServerSocket shall block infinitely.
+ serverSocketChannel.foreach(_.socket().setSoTimeout(0))
+ }
+
new Thread("accept-connections") {
setDaemon(true)
override def run(): Unit = {
- while (!serverSocket.get.isClosed()) {
- var sock: Socket = null
+ while (serverSocketChannel.get.isOpen()) {
+ var sock: SocketChannel = null
try {
- sock = serverSocket.get.accept()
+ sock = serverSocketChannel.get.accept()
// Wait for function call from python side.
- sock.setSoTimeout(10000)
+ if (!isUnixDomainSock) sock.socket().setSoTimeout(10000)
authHelper.authClient(sock)
- val input = new DataInputStream(sock.getInputStream())
+ val input = new DataInputStream(Channels.newInputStream(sock))
val requestMethod = input.readInt()
// The BarrierTaskContext function may wait infinitely, socket shall not timeout
// before the function finishes.
- sock.setSoTimeout(0)
+ if (!isUnixDomainSock) sock.socket().setSoTimeout(0)
requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
barrierAndServe(requestMethod, sock)
@@ -385,13 +403,14 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
barrierAndServe(requestMethod, sock, message)
case _ =>
val out = new DataOutputStream(new BufferedOutputStream(
- sock.getOutputStream))
+ Channels.newOutputStream(sock)))
writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
}
} catch {
- case e: SocketException if e.getMessage.contains("Socket closed") =>
- // It is possible that the ServerSocket is not closed, but the native socket
- // has already been closed, we shall catch and silently ignore this case.
+ case _: AsynchronousCloseException =>
+ // Ignore to reduce noise. These sockets are closed by the task completion
+ // listeners when tasks finish.
+ if (isUnixDomainSock) sockPath.delete()
} finally {
if (sock != null) {
sock.close()
@@ -401,33 +420,35 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
}.start()
}
- val secret = if (isBarrier) {
- authHelper.secret
- } else {
- ""
- }
if (isBarrier) {
// Close ServerSocket on task completion.
- serverSocket.foreach { server =>
- context.addTaskCompletionListener[Unit](_ => server.close())
+ serverSocketChannel.foreach { server =>
+ context.addTaskCompletionListener[Unit] { _ =>
+ server.close()
+ if (isUnixDomainSock) sockPath.delete()
+ }
}
- val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
- if (boundPort == -1) {
- val message = "ServerSocket failed to bind to Java side."
- logError(message)
- throw new SparkException(message)
+ if (isUnixDomainSock) {
+ logDebug(s"Started ServerSocket on with Unix Domain Socket $sockPath.")
+ dataOut.writeBoolean(/* isBarrier = */true)
+ dataOut.writeInt(-1)
+ PythonRDD.writeUTF(sockPath.getPath, dataOut)
+ } else {
+ val boundPort: Int = serverSocketChannel.map(_.socket().getLocalPort).getOrElse(-1)
+ if (boundPort == -1) {
+ val message = "ServerSocket failed to bind to Java side."
+ logError(message)
+ throw new SparkException(message)
+ }
+ logDebug(s"Started ServerSocket on port $boundPort.")
+ dataOut.writeBoolean(/* isBarrier = */true)
+ dataOut.writeInt(boundPort)
+ PythonRDD.writeUTF(authHelper.secret, dataOut)
}
- logDebug(s"Started ServerSocket on port $boundPort.")
- dataOut.writeBoolean(/* isBarrier = */true)
- dataOut.writeInt(boundPort)
} else {
dataOut.writeBoolean(/* isBarrier = */false)
- dataOut.writeInt(0)
}
// Write out the TaskContextInfo
- val secretBytes = secret.getBytes(UTF_8)
- dataOut.writeInt(secretBytes.length)
- dataOut.write(secretBytes, 0, secretBytes.length)
dataOut.writeInt(context.stageId())
dataOut.writeInt(context.partitionId())
dataOut.writeInt(context.attemptNumber())
@@ -485,12 +506,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
/**
* Gateway to call BarrierTaskContext methods.
*/
- def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = {
+ def barrierAndServe(requestMethod: Int, sock: SocketChannel, message: String = ""): Unit = {
require(
- serverSocket.isDefined,
+ serverSocketChannel.isDefined,
"No available ServerSocket to redirect the BarrierTaskContext method call."
)
- val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+ val out = new DataOutputStream(new BufferedOutputStream(Channels.newOutputStream(sock)))
try {
val messages = requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
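
On the server side, the PythonRunner hunks above bind a ServerSocketChannel either to a loopback TCP port or to a Unix domain socket file under java.io.tmpdir. A minimal standalone sketch of the UDS variant (assuming JDK 16+; this is an illustration, not the Spark code itself):

```scala
import java.io.File
import java.net.{StandardProtocolFamily, UnixDomainSocketAddress}
import java.nio.channels.{Channels, ServerSocketChannel}
import java.util.UUID

object UdsServerSketch {
  def main(args: Array[String]): Unit = {
    // Socket file under java.io.tmpdir, mirroring the ".<uuid>.sock" naming in the patch.
    val sockPath = new File(System.getProperty("java.io.tmpdir"), s".${UUID.randomUUID()}.sock")
    sockPath.deleteOnExit()
    val server = ServerSocketChannel.open(StandardProtocolFamily.UNIX)
    server.bind(UnixDomainSocketAddress.of(sockPath.getPath))
    try {
      val client = server.accept() // blocks until a client connects
      val in = Channels.newInputStream(client)
      // ... exchange data over the channel's streams ...
      in.close()
      client.close()
    } finally {
      server.close()
      sockPath.delete()
    }
  }
}
```

Because a UDS endpoint is a filesystem path rather than a port, the patch writes `-1` in place of a port number and cleans the socket file up on task completion.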
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 19a0670769675..64b29585a0d92 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -18,10 +18,11 @@
package org.apache.spark.api.python
import java.io.{DataInputStream, DataOutputStream, EOFException, File, InputStream}
-import java.net.{InetAddress, InetSocketAddress, SocketException}
+import java.net.{InetAddress, InetSocketAddress, SocketException, StandardProtocolFamily, UnixDomainSocketAddress}
import java.net.SocketTimeoutException
import java.nio.channels._
import java.util.Arrays
+import java.util.UUID
import java.util.concurrent.TimeUnit
import javax.annotation.concurrent.GuardedBy
@@ -33,6 +34,7 @@ import org.apache.spark._
import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
+import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR, PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util.{RedirectThread, Utils}
@@ -97,6 +99,7 @@ private[spark] class PythonWorkerFactory(
}
private val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
+ private val isUnixDomainSock = authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
@GuardedBy("self")
private var daemon: Process = null
@@ -106,6 +109,8 @@ private[spark] class PythonWorkerFactory(
@GuardedBy("self")
private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, ProcessHandle]()
@GuardedBy("self")
+ private var daemonSockPath: String = _
+ @GuardedBy("self")
private val idleWorkers = new mutable.Queue[PythonWorker]()
@GuardedBy("self")
private var lastActivityNs = 0L
@@ -152,7 +157,11 @@ private[spark] class PythonWorkerFactory(
private def createThroughDaemon(): (PythonWorker, Option[ProcessHandle]) = {
def createWorker(): (PythonWorker, Option[ProcessHandle]) = {
- val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
+ val socketChannel = if (isUnixDomainSock) {
+ SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath))
+ } else {
+ SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
+ }
// These calls are blocking.
val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
if (pid < 0) {
@@ -161,7 +170,7 @@ private[spark] class PythonWorkerFactory(
val processHandle = ProcessHandle.of(pid).orElseThrow(
() => new IllegalStateException("Python daemon failed to launch worker.")
)
- authHelper.authToServer(socketChannel.socket())
+ authHelper.authToServer(socketChannel)
socketChannel.configureBlocking(false)
val worker = PythonWorker(socketChannel)
daemonWorkers.put(worker, processHandle)
@@ -192,9 +201,19 @@ private[spark] class PythonWorkerFactory(
private[spark] def createSimpleWorker(
blockingMode: Boolean): (PythonWorker, Option[ProcessHandle]) = {
var serverSocketChannel: ServerSocketChannel = null
+ lazy val sockPath = new File(
+ authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+ .getOrElse(System.getProperty("java.io.tmpdir")),
+ s".${UUID.randomUUID()}.sock")
try {
- serverSocketChannel = ServerSocketChannel.open()
- serverSocketChannel.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
+ if (isUnixDomainSock) {
+ serverSocketChannel = ServerSocketChannel.open(StandardProtocolFamily.UNIX)
+ sockPath.deleteOnExit()
+ serverSocketChannel.bind(UnixDomainSocketAddress.of(sockPath.getPath))
+ } else {
+ serverSocketChannel = ServerSocketChannel.open()
+ serverSocketChannel.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
+ }
// Create and start the worker
val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule))
@@ -209,9 +228,14 @@ private[spark] class PythonWorkerFactory(
workerEnv.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
workerEnv.put("PYTHONUNBUFFERED", "YES")
- workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocketChannel.socket().getLocalPort
- .toString)
- workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+ if (isUnixDomainSock) {
+ workerEnv.put("PYTHON_WORKER_FACTORY_SOCK_PATH", sockPath.getPath)
+ workerEnv.put("PYTHON_UNIX_DOMAIN_ENABLED", "True")
+ } else {
+ workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocketChannel.socket().getLocalPort
+ .toString)
+ workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+ }
if (Utils.preferIPv6) {
workerEnv.put("SPARK_PREFER_IPV6", "True")
}
@@ -233,7 +257,7 @@ private[spark] class PythonWorkerFactory(
throw new SocketTimeoutException(
"Timed out while waiting for the Python worker to connect back")
}
- authHelper.authClient(socketChannel.socket())
+ authHelper.authClient(socketChannel)
// TODO: When we drop JDK 8, we can just use workerProcess.pid()
val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
if (pid < 0) {
@@ -254,6 +278,7 @@ private[spark] class PythonWorkerFactory(
} finally {
if (serverSocketChannel != null) {
serverSocketChannel.close()
+ if (isUnixDomainSock) sockPath.delete()
}
}
}
@@ -278,7 +303,15 @@ private[spark] class PythonWorkerFactory(
val workerEnv = pb.environment()
workerEnv.putAll(envVars.asJava)
workerEnv.put("PYTHONPATH", pythonPath)
- workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+ if (isUnixDomainSock) {
+ workerEnv.put(
+ "PYTHON_WORKER_FACTORY_SOCK_DIR",
+ authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+ .getOrElse(System.getProperty("java.io.tmpdir")))
+ workerEnv.put("PYTHON_UNIX_DOMAIN_ENABLED", "True")
+ } else {
+ workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+ }
if (Utils.preferIPv6) {
workerEnv.put("SPARK_PREFER_IPV6", "True")
}
@@ -288,7 +321,11 @@ private[spark] class PythonWorkerFactory(
val in = new DataInputStream(daemon.getInputStream)
try {
- daemonPort = in.readInt()
+ if (isUnixDomainSock) {
+ daemonSockPath = PythonWorkerUtils.readUTF(in)
+ } else {
+ daemonPort = in.readInt()
+ }
} catch {
case _: EOFException if daemon.isAlive =>
throw SparkCoreErrors.eofExceptionWhileReadPortNumberError(
@@ -301,10 +338,14 @@ private[spark] class PythonWorkerFactory(
// test that the returned port number is within a valid range.
// note: this does not cover the case where the port number
// is arbitrary data but is also coincidentally within range
- if (daemonPort < 1 || daemonPort > 0xffff) {
+ val isMalformedPort = !isUnixDomainSock && (daemonPort < 1 || daemonPort > 0xffff)
+ val isMalformedSockPath = isUnixDomainSock && !new File(daemonSockPath).exists()
+ val errorMsg =
+ if (isUnixDomainSock) daemonSockPath else f"$daemonPort (0x$daemonPort%08x)"
+ if (isMalformedPort || isMalformedSockPath) {
val exceptionMessage = f"""
- |Bad data in $daemonModule's standard output. Invalid port number:
- | $daemonPort (0x$daemonPort%08x)
+ |Bad data in $daemonModule's standard output. Invalid port number/socket path:
+ | $errorMsg
|Python command to execute the daemon was:
| ${command.asScala.mkString(" ")}
|Check that you don't have any unexpected modules or libraries in
@@ -407,6 +448,7 @@ private[spark] class PythonWorkerFactory(
daemon = null
daemonPort = 0
+ daemonSockPath = null
} else {
simpleWorkers.values.foreach(_.destroy())
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
index ae3614445be6e..0a6def051a349 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
@@ -117,9 +117,15 @@ private[spark] object PythonWorkerUtils extends Logging {
}
}
val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
- dataOut.writeInt(server.port)
- logTrace(s"broadcast decryption server setup on ${server.port}")
- writeUTF(server.secret, dataOut)
+ server.connInfo match {
+ case portNum: Int =>
+ dataOut.writeInt(portNum)
+ writeUTF(server.secret, dataOut)
+ case sockPath: String =>
+ dataOut.writeInt(-1)
+ writeUTF(sockPath, dataOut)
+ }
+ logTrace(s"broadcast decryption server setup on ${server.connInfo}")
sendBidsToRemove()
idsAndFiles.foreach { case (id, _) =>
// send new broadcast
diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 6f9708def2f2b..7eba574751b46 100644
--- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -18,6 +18,7 @@
package org.apache.spark.api.python
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
+import java.nio.channels.Channels
import scala.jdk.CollectionConverters._
@@ -25,7 +26,7 @@ import org.apache.spark.{SparkEnv, SparkPythonException}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{PYTHON_WORKER_MODULE, PYTHON_WORKER_RESPONSE, SESSION_ID}
import org.apache.spark.internal.config.BUFFER_SIZE
-import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
private[spark] object StreamingPythonRunner {
@@ -45,6 +46,7 @@ private[spark] class StreamingPythonRunner(
sessionId: String,
workerModule: String) extends Logging {
private val conf = SparkEnv.get.conf
+ private val isUnixDomainSock = conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
@@ -78,14 +80,20 @@ private[spark] class StreamingPythonRunner(
pythonWorker = Some(worker)
pythonWorkerFactory = Some(workerFactory)
- val socket = pythonWorker.get.channel.socket()
- val stream = new BufferedOutputStream(socket.getOutputStream, bufferSize)
- val dataIn = new DataInputStream(new BufferedInputStream(socket.getInputStream, bufferSize))
+ val socketChannel = pythonWorker.get.channel
+ val stream = new BufferedOutputStream(Channels.newOutputStream(socketChannel), bufferSize)
+ val dataIn = new DataInputStream(
+ new BufferedInputStream(Channels.newInputStream(socketChannel), bufferSize))
val dataOut = new DataOutputStream(stream)
- val originalTimeout = socket.getSoTimeout()
- // Set timeout to 5 minute during initialization config transmission
- socket.setSoTimeout(5 * 60 * 1000)
+ val originalTimeout = if (!isUnixDomainSock) {
+ val timeout = socketChannel.socket().getSoTimeout()
+ // Set timeout to 5 minute during initialization config transmission
+ socketChannel.socket().setSoTimeout(5 * 60 * 1000)
+ Some(timeout)
+ } else {
+ None
+ }
val resFromPython = try {
PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
@@ -111,7 +119,7 @@ private[spark] class StreamingPythonRunner(
// Set timeout back to the original timeout
// Should be infinity by default
- socket.setSoTimeout(originalTimeout)
+ originalTimeout.foreach(v => socketChannel.socket().setSoTimeout(v))
if (resFromPython != 0) {
val errMessage = PythonWorkerUtils.readUTF(dataIn)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
index ac6826a9ec774..5c45986a8f9a0 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.security.SocketAuthHelper
private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) {
+ override val isUnixDomainSock = false
override protected def readUtf8(s: Socket): String = {
SerDe.readString(new DataInputStream(s.getInputStream()))
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index ff6ed9f86b554..3b309e0939700 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -18,7 +18,7 @@
package org.apache.spark.api.r
import java.io.{File, OutputStream}
-import java.net.Socket
+import java.nio.channels.{Channels, SocketChannel}
import java.util.{Map => JMap}
import scala.jdk.CollectionConverters._
@@ -179,8 +179,8 @@ private[spark] class RParallelizeServer(sc: JavaSparkContext, parallelism: Int)
extends SocketAuthServer[JavaRDD[Array[Byte]]](
new RAuthHelper(SparkEnv.get.conf), "sparkr-parallelize-server") {
- override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
- val in = sock.getInputStream()
+ override def handleConnection(sock: SocketChannel): JavaRDD[Array[Byte]] = {
+ val in = Channels.newInputStream(sock)
JavaRDD.readRDDFromInputStream(sc.sc, in, parallelism)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
index 57b0647e59fd9..e21c772c00779 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -31,7 +31,6 @@ import org.apache.spark.network.crypto.AuthServerBootstrap
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap}
import org.apache.spark.network.shuffle.ExternalBlockHandler
-import org.apache.spark.network.shuffledb.DBBackend
import org.apache.spark.network.util.TransportConf
import org.apache.spark.util.{ShutdownHookManager, Utils}
@@ -86,11 +85,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
protected def newShuffleBlockHandler(conf: TransportConf): ExternalBlockHandler = {
if (sparkConf.get(config.SHUFFLE_SERVICE_DB_ENABLED) && enabled) {
val shuffleDBName = sparkConf.get(config.SHUFFLE_SERVICE_DB_BACKEND)
- val dbBackend = DBBackend.byName(shuffleDBName)
- logInfo(log"Use ${MDC(SHUFFLE_DB_BACKEND_NAME, dbBackend.name())} as the implementation of " +
+ logInfo(
+ log"Use ${MDC(SHUFFLE_DB_BACKEND_NAME, shuffleDBName.name())} as the implementation of " +
log"${MDC(SHUFFLE_DB_BACKEND_KEY, config.SHUFFLE_SERVICE_DB_BACKEND.key)}")
new ExternalBlockHandler(conf,
- findRegisteredExecutorsDBFile(dbBackend.fileName(registeredExecutorsDB)))
+ findRegisteredExecutorsDBFile(shuffleDBName.fileName(registeredExecutorsDB)))
} else {
new ExternalBlockHandler(conf, null)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/EnvironmentPage.scala
index c05b20d30b983..977f8cfae75ef 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/EnvironmentPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/EnvironmentPage.scala
@@ -142,6 +142,7 @@ private[ui] class EnvironmentPage(
{environmentVariablesTable}
+
UIUtils.basicSparkPage(request, content, "Environment")
}
diff --git a/core/src/main/scala/org/apache/spark/errors/SparkCoreErrors.scala b/core/src/main/scala/org/apache/spark/errors/SparkCoreErrors.scala
index 8a790291b4e72..30660a177416d 100644
--- a/core/src/main/scala/org/apache/spark/errors/SparkCoreErrors.scala
+++ b/core/src/main/scala/org/apache/spark/errors/SparkCoreErrors.scala
@@ -267,19 +267,6 @@ private[spark] object SparkCoreErrors {
new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3026")
}
- def unrecognizedSchedulerModePropertyError(
- schedulerModeProperty: String,
- schedulingModeConf: String): Throwable = {
- new SparkException(
- errorClass = "_LEGACY_ERROR_TEMP_3027",
- messageParameters = Map(
- "schedulerModeProperty" -> schedulerModeProperty,
- "schedulingModeConf" -> schedulingModeConf
- ),
- cause = null
- )
- }
-
def sparkError(errorMsg: String): Throwable = {
new SparkException(
errorClass = "_LEGACY_ERROR_TEMP_3028",
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 50bf0bea87f94..97754d5457bec 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -186,7 +186,8 @@ private[spark] class Executor(
val currentJars = new HashMap[String, Long]
val currentArchives = new HashMap[String, Long]
val urlClassLoader =
- createClassLoader(currentJars, isStubbingEnabledForState(jobArtifactState.uuid))
+ createClassLoader(currentJars, isStubbingEnabledForState(jobArtifactState.uuid),
+ isDefaultState(jobArtifactState.uuid))
val replClassLoader = addReplClassLoaderIfNeeded(
urlClassLoader, jobArtifactState.replClassDirUri, jobArtifactState.uuid)
new IsolatedSessionState(
@@ -307,7 +308,7 @@ private[spark] class Executor(
"executor-heartbeater",
HEARTBEAT_INTERVAL_MS)
- // must be initialized before running startDriverHeartbeat()
+ // must be initialized before running heartbeater.start()
private val heartbeatReceiverRef =
RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv)
@@ -544,7 +545,8 @@ private[spark] class Executor(
t.metrics.setExecutorRunTime(TimeUnit.NANOSECONDS.toMillis(
// SPARK-32898: it's possible that a task is killed when taskStartTimeNs has the initial
// value(=0) still. In this case, the executorRunTime should be considered as 0.
- if (taskStartTimeNs > 0) System.nanoTime() - taskStartTimeNs else 0))
+ if (taskStartTimeNs > 0) (System.nanoTime() - taskStartTimeNs) * taskDescription.cpus
+ else 0))
t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
})
@@ -701,7 +703,8 @@ private[spark] class Executor(
(taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
// We need to subtract Task.run()'s deserialization time to avoid double-counting
task.metrics.setExecutorRunTime(TimeUnit.NANOSECONDS.toMillis(
- (taskFinishNs - taskStartTimeNs) - task.executorDeserializeTimeNs))
+ (taskFinishNs - taskStartTimeNs) * taskDescription.cpus
+ - task.executorDeserializeTimeNs))
task.metrics.setExecutorCpuTime(
(taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
@@ -930,21 +933,17 @@ private[spark] class Executor(
}
private def setMDCForTask(taskName: String, mdc: Seq[(String, String)]): Unit = {
- try {
+ if (Executor.mdcIsSupported) {
mdc.foreach { case (key, value) => MDC.put(key, value) }
// avoid overriding the taskName by the user
MDC.put(taskNameMDCKey, taskName)
- } catch {
- case _: NoSuchFieldError => logInfo("MDC is not supported.")
}
}
private def cleanMDCForTask(taskName: String, mdc: Seq[(String, String)]): Unit = {
- try {
+ if (Executor.mdcIsSupported) {
mdc.foreach { case (key, _) => MDC.remove(key) }
MDC.remove(taskNameMDCKey)
- } catch {
- case _: NoSuchFieldError => logInfo("MDC is not supported.")
}
}
@@ -1072,7 +1071,8 @@ private[spark] class Executor(
*/
private def createClassLoader(
currentJars: HashMap[String, Long],
- useStub: Boolean): MutableURLClassLoader = {
+ useStub: Boolean,
+ isDefaultSession: Boolean): MutableURLClassLoader = {
// Bootstrap the list of jars with the user class path.
val now = System.currentTimeMillis()
userClassPath.foreach { url =>
@@ -1084,10 +1084,12 @@ private[spark] class Executor(
val urls = userClassPath.toArray ++ currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}
- createClassLoader(urls, useStub)
+ createClassLoader(urls, useStub, isDefaultSession)
}
- private def createClassLoader(urls: Array[URL], useStub: Boolean): MutableURLClassLoader = {
+ private def createClassLoader(urls: Array[URL],
+ useStub: Boolean,
+ isDefaultSession: Boolean): MutableURLClassLoader = {
logInfo(
log"Starting executor with user classpath" +
log" (userClassPathFirst =" +
@@ -1096,33 +1098,45 @@ private[spark] class Executor(
)
if (useStub) {
- createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES))
+ createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES), isDefaultSession)
} else {
- createClassLoader(urls)
+ createClassLoader(urls, isDefaultSession)
}
}
- private def createClassLoader(urls: Array[URL]): MutableURLClassLoader = {
+ private def createClassLoader(urls: Array[URL],
+ isDefaultSession: Boolean): MutableURLClassLoader = {
+ // SPARK-51537: The isolated session must *inherit* the classloader from the default session,
+ // which has already included the global JARs specified via --jars. For Spark plugins, we
+ // cannot simply add the plugin JARs to the classpath of the isolated session, as this may
+ // cause the plugin to be reloaded, leading to potential conflicts or unexpected behavior.
+ val loader = if (isDefaultSession) systemLoader else defaultSessionState.replClassLoader
if (userClassPathFirst) {
- new ChildFirstURLClassLoader(urls, systemLoader)
+ new ChildFirstURLClassLoader(urls, loader)
} else {
- new MutableURLClassLoader(urls, systemLoader)
+ new MutableURLClassLoader(urls, loader)
}
}
private def createClassLoaderWithStub(
urls: Array[URL],
- binaryName: Seq[String]): MutableURLClassLoader = {
+ binaryName: Seq[String],
+ isDefaultSession: Boolean): MutableURLClassLoader = {
+ // SPARK-51537: The isolated session must *inherit* the classloader from the default session,
+ // which has already included the global JARs specified via --jars. For Spark plugins, we
+ // cannot simply add the plugin JARs to the classpath of the isolated session, as this may
+ // cause the plugin to be reloaded, leading to potential conflicts or unexpected behavior.
+ val loader = if (isDefaultSession) systemLoader else defaultSessionState.replClassLoader
if (userClassPathFirst) {
// user -> (sys -> stub)
val stubClassLoader =
- StubClassLoader(systemLoader, binaryName)
+ StubClassLoader(loader, binaryName)
new ChildFirstURLClassLoader(urls, stubClassLoader)
} else {
// sys -> user -> stub
val stubClassLoader =
StubClassLoader(null, binaryName)
- new ChildFirstURLClassLoader(urls, stubClassLoader, systemLoader)
+ new ChildFirstURLClassLoader(urls, stubClassLoader, loader)
}
}
@@ -1229,7 +1243,8 @@ private[spark] class Executor(
}
if (renewClassLoader) {
// Recreate the class loader to ensure all classes are updated.
- state.urlClassLoader = createClassLoader(state.urlClassLoader.getURLs, useStub = true)
+ state.urlClassLoader = createClassLoader(state.urlClassLoader.getURLs,
+ useStub = true, isDefaultState(state.sessionUUID))
state.replClassLoader =
addReplClassLoaderIfNeeded(state.urlClassLoader, state.replClassDirUri, state.sessionUUID)
}
@@ -1299,7 +1314,7 @@ private[spark] class Executor(
}
}
-private[spark] object Executor {
+private[spark] object Executor extends Logging {
// This is reserved for internal use by components that need to read task properties before a
// task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be
// used instead.
@@ -1308,6 +1323,21 @@ private[spark] object Executor {
// Used to store executorSource, for local mode only
var executorSourceLocalModeOnly: ExecutorSource = null
+ lazy val mdcIsSupported: Boolean = {
+ try {
+ // This tests if any class initialization error is thrown
+ val testKey = System.nanoTime().toString
+ MDC.put(testKey, "testValue")
+ MDC.remove(testKey)
+
+ true
+ } catch {
+ case t: Throwable =>
+ logInfo("MDC is not supported.", t)
+ false
+ }
+ }
+
/**
* Whether a `Throwable` thrown from a task is a fatal error. We will use this to decide whether
* to kill the executor.
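
The Executor change above replaces per-call `NoSuchFieldError` handling with a one-time probe cached in `Executor.mdcIsSupported`. A standalone sketch of that probe pattern, assuming an SLF4J-style MDC on the classpath; the object name and MDC key are illustrative only.

```scala
import org.slf4j.MDC

object MdcSupport {
  // Probe MDC once at initialization; if the logging backend does not support it,
  // the first put/remove throws and the negative result is cached.
  lazy val mdcIsSupported: Boolean =
    try {
      val testKey = s"mdc-probe-${System.nanoTime()}"
      MDC.put(testKey, "testValue")
      MDC.remove(testKey)
      true
    } catch {
      case _: Throwable => false
    }

  // Set per-task MDC values only when the probe succeeded.
  def withTaskMdc[T](taskName: String)(body: => T): T = {
    if (mdcIsSupported) MDC.put("task_name", taskName)
    try body
    finally if (mdcIsSupported) MDC.remove("task_name")
  }
}
```

Probing once avoids paying for (and logging) the failure on every task start and cleanup, which is the point of the hunk above.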
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorClassLoader.scala
index 48d7f150ad9bd..6f8138da6f4fb 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorClassLoader.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorClassLoader.scala
@@ -59,9 +59,6 @@ class ExecutorClassLoader(
val parentLoader = new ParentClassLoader(parent)
- // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
- private[executor] var httpUrlConnectionTimeoutMillis: Int = -1
-
private val fetchFn: (String) => InputStream = uri.getScheme() match {
case "spark" => getClassFileInputStreamFromSparkRPC
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
index 1f827e8dc4491..46d54be92f3d6 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -70,6 +70,30 @@ private[spark] object Python {
.booleanConf
.createWithDefault(false)
+ val PYTHON_UNIX_DOMAIN_SOCKET_ENABLED = ConfigBuilder("spark.python.unix.domain.socket.enabled")
+ .doc("When set to true, the Python driver uses a Unix domain socket for operations like " +
+ "creating or collecting a DataFrame from local data, using accumulators, and executing " +
+ "Python functions with PySpark such as Python UDFs. This configuration only applies " +
+ "to Spark Classic and Spark Connect server.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(sys.env.get("PYSPARK_UDS_MODE").contains("true"))
+
+ val PYTHON_UNIX_DOMAIN_SOCKET_DIR = ConfigBuilder("spark.python.unix.domain.socket.dir")
+ .doc("When specified, it uses the directory to create Unix domain socket files. " +
+ "Otherwise, it uses the default location of the temporary directory set in " +
+ s"'java.io.tmpdir' property. This is used when ${PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key} " +
+ "is enabled.")
+ .internal()
+ .version("4.1.0")
+ .stringConf
+ // UDS requires the socket path to be shorter than 104 characters. We use a UUID
+ // (36 characters) plus the prefix "." (1), the suffix ".sock" (5), and the path separator (1).
+ .checkValue(
+ _.length <= (104 - (36 + 1 + 5 + 1)),
+ s"The directory path should be shorter than ${(104 - (36 + 1 + 5 + 1))} characters.")
+ .createOptional
+
private val PYTHON_WORKER_IDLE_TIMEOUT_SECONDS_KEY = "spark.python.worker.idleTimeoutSeconds"
private val PYTHON_WORKER_KILL_ON_IDLE_TIMEOUT_KEY = "spark.python.worker.killOnIdleTimeout"
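
As a usage illustration of the two configs added above (not part of this patch), the Unix-domain-socket mode could be enabled from a driver program or spark-shell; exporting PYSPARK_UDS_MODE=true flips the default instead. A minimal sketch:

```scala
import org.apache.spark.SparkConf

// Opt in to Unix domain sockets for the local JVM <-> Python channels, and
// optionally pin the socket directory (defaults to java.io.tmpdir). Keep the
// directory path short: the full socket path must stay under the ~104-character
// limit enforced by the checkValue above.
val conf = new SparkConf()
  .set("spark.python.unix.domain.socket.enabled", "true")
  .set("spark.python.unix.domain.socket.dir", "/tmp")
```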
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 3ce374d0477d8..039387cba719f 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -807,10 +807,8 @@ package object config {
.doc("Specifies a disk-based store used in shuffle service local db. " +
"ROCKSDB or LEVELDB (deprecated).")
.version("3.4.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(DBBackend.values.map(_.toString).toSet)
- .createWithDefault(DBBackend.ROCKSDB.name)
+ .enumConf(classOf[DBBackend])
+ .createWithDefault(DBBackend.ROCKSDB)
private[spark] val SHUFFLE_SERVICE_PORT =
ConfigBuilder("spark.shuffle.service.port").version("1.2.0").intConf.createWithDefault(7337)
@@ -2295,9 +2293,8 @@ package object config {
private[spark] val SCHEDULER_MODE =
ConfigBuilder("spark.scheduler.mode")
.version("0.8.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .createWithDefault(SchedulingMode.FIFO.toString)
+ .enumConf(SchedulingMode)
+ .createWithDefault(SchedulingMode.FIFO)
private[spark] val SCHEDULER_REVIVE_INTERVAL =
ConfigBuilder("spark.scheduler.revive.interval")
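
The hunks above switch the shuffle service DB backend entry and SCHEDULER_MODE from stringConf chains with manual upper-casing and checkValues to enumConf, so parsing and validation live in the config entry and call sites read the enum value directly. A standalone sketch of roughly the validation such an entry performs for a Scala Enumeration; whether the real enumConf also upper-cases the raw value is not shown in this hunk, and the names below are illustrative, but the error message mirrors the one asserted in ConfigEntrySuite later in this patch:

import java.util.Locale

object SchedulingModeDemo extends Enumeration {
  val FAIR, FIFO, NONE = Value
}

def parseEnum(key: String, raw: String): SchedulingModeDemo.Value = {
  val normalized = raw.toUpperCase(Locale.ROOT)
  SchedulingModeDemo.values.find(_.toString == normalized).getOrElse {
    throw new IllegalArgumentException(
      s"$key should be one of ${SchedulingModeDemo.values.mkString(", ")}, but was $raw")
  }
}

// parseEnum("spark.scheduler.mode", "fair")  == SchedulingModeDemo.FAIR
// parseEnum("spark.scheduler.mode", "bogus") throws IllegalArgumentException
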
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index 6f64dff3f39d6..bea49fb279ee3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -89,7 +89,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, sc: SparkContext
log"${MDC(LogKeys.FILE_NAME, DEFAULT_SCHEDULER_FILE)}")
Some((is, DEFAULT_SCHEDULER_FILE))
} else {
- val schedulingMode = SchedulingMode.withName(sc.conf.get(SCHEDULER_MODE))
+ val schedulingMode = sc.conf.get(SCHEDULER_MODE)
rootPool.addSchedulable(new Pool(
DEFAULT_POOL_NAME, schedulingMode, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT))
logInfo(log"Fair scheduler configuration not found, created default pool: " +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 13d2d650fcee0..13018da5bc274 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -181,15 +181,7 @@ private[spark] class TaskSchedulerImpl(
private var schedulableBuilder: SchedulableBuilder = null
// default scheduler is FIFO
- private val schedulingModeConf = conf.get(SCHEDULER_MODE)
- val schedulingMode: SchedulingMode =
- try {
- SchedulingMode.withName(schedulingModeConf)
- } catch {
- case e: java.util.NoSuchElementException =>
- throw SparkCoreErrors.unrecognizedSchedulerModePropertyError(SCHEDULER_MODE_PROPERTY,
- schedulingModeConf)
- }
+ val schedulingMode: SchedulingMode = conf.get(SCHEDULER_MODE)
val rootPool: Pool = new Pool("", schedulingMode, 0, 0)
@@ -780,8 +772,6 @@ private[spark] class TaskSchedulerImpl(
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer): Unit = {
- var failedExecutor: Option[String] = None
- var reason: Option[ExecutorLossReason] = None
synchronized {
try {
Option(taskIdToTaskSetManager.get(tid)) match {
@@ -809,12 +799,6 @@ private[spark] class TaskSchedulerImpl(
case e: Exception => logError("Exception in statusUpdate", e)
}
}
- // Update the DAGScheduler without holding a lock on this, since that can deadlock
- if (failedExecutor.isDefined) {
- assert(reason.isDefined)
- dagScheduler.executorLost(failedExecutor.get, reason.get)
- backend.reviveOffers()
- }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index f800553c5388b..ecebb97ecfc1d 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -19,9 +19,11 @@ package org.apache.spark.security
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
+import java.nio.channels.SocketChannel
import java.nio.charset.StandardCharsets.UTF_8
import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR, PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.Utils
@@ -35,6 +37,9 @@ import org.apache.spark.util.Utils
* There's no secrecy, so this relies on the sockets being either local or somehow encrypted.
*/
private[spark] class SocketAuthHelper(val conf: SparkConf) {
+ val isUnixDomainSock: Boolean = conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
+ lazy val sockDir: String =
+ conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR).getOrElse(System.getProperty("java.io.tmpdir"))
val secret = Utils.createSecret(conf)
@@ -47,6 +52,11 @@ private[spark] class SocketAuthHelper(val conf: SparkConf) {
* @param s The client socket.
* @throws IllegalArgumentException If authentication fails.
*/
+ def authClient(socket: SocketChannel): Unit = {
+ if (isUnixDomainSock) return
+ authClient(socket.socket())
+ }
+
def authClient(s: Socket): Unit = {
var shouldClose = true
try {
@@ -80,7 +90,9 @@ private[spark] class SocketAuthHelper(val conf: SparkConf) {
* @param s The socket connected to the server.
* @throws IllegalArgumentException If authentication fails.
*/
- def authToServer(s: Socket): Unit = {
+ def authToServer(socket: SocketChannel): Unit = {
+ if (isUnixDomainSock) return
+ val s = socket.socket()
var shouldClose = true
try {
writeUtf8(secret, s)
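
The new SocketChannel overloads of authClient and authToServer return early when Unix domain sockets are in use, presumably because the socket file is reachable only through local filesystem permissions, so the TCP shared-secret exchange adds nothing there. The guard pattern in isolation (a hypothetical helper, not the real SocketAuthHelper):

import java.net.Socket
import java.nio.channels.SocketChannel

def authIfNeeded(channel: SocketChannel, isUnixDomainSock: Boolean)
    (tcpHandshake: Socket => Unit): Unit = {
  if (isUnixDomainSock) return        // UDS: skip the secret exchange entirely
  tcpHandshake(channel.socket())      // TCP: run the existing handshake on the underlying socket
}
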
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index 9efe2af5fcc8a..b0446a4f2febf 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -17,8 +17,10 @@
package org.apache.spark.security
-import java.io.{BufferedOutputStream, OutputStream}
-import java.net.{InetAddress, ServerSocket, Socket}
+import java.io.{BufferedOutputStream, File, OutputStream}
+import java.net.{InetAddress, InetSocketAddress, StandardProtocolFamily, UnixDomainSocketAddress}
+import java.nio.channels.{Channels, ServerSocketChannel, SocketChannel}
+import java.util.UUID
import scala.concurrent.Promise
import scala.concurrent.duration.Duration
@@ -46,44 +48,70 @@ private[spark] abstract class SocketAuthServer[T](
def this(threadName: String) = this(SparkEnv.get, threadName)
private val promise = Promise[T]()
+ private val isUnixDomainSock: Boolean = authHelper.isUnixDomainSock
- private def startServer(): (Int, String) = {
+ private def startServer(): (Any, String) = {
logTrace("Creating listening socket")
- val address = InetAddress.getLoopbackAddress()
- val serverSocket = new ServerSocket(0, 1, address)
+ lazy val sockPath = new File(authHelper.sockDir, s".${UUID.randomUUID()}.sock")
+
+ val (serverSocketChannel, address) = if (isUnixDomainSock) {
+ val address = UnixDomainSocketAddress.of(sockPath.getPath)
+ val serverChannel = ServerSocketChannel.open(StandardProtocolFamily.UNIX)
+ sockPath.deleteOnExit()
+ serverChannel.bind(address)
+ (serverChannel, address)
+ } else {
+ val address = InetAddress.getLoopbackAddress()
+ val serverChannel = ServerSocketChannel.open()
+ serverChannel.bind(new InetSocketAddress(address, 0), 1)
+ (serverChannel, address)
+ }
+
// Close the socket if no connection in the configured seconds
val timeout = authHelper.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt
logTrace(s"Setting timeout to $timeout sec")
- serverSocket.setSoTimeout(timeout * 1000)
+ if (!isUnixDomainSock) serverSocketChannel.socket().setSoTimeout(timeout * 1000)
new Thread(threadName) {
setDaemon(true)
override def run(): Unit = {
- var sock: Socket = null
+ var sock: SocketChannel = null
try {
- logTrace(s"Waiting for connection on $address with port ${serverSocket.getLocalPort}")
- sock = serverSocket.accept()
- logTrace(s"Connection accepted from address ${sock.getRemoteSocketAddress}")
+ if (isUnixDomainSock) {
+ logTrace(s"Waiting for connection on $address.")
+ } else {
+ logTrace(
+ s"Waiting for connection on $address with port " +
+ s"${serverSocketChannel.socket().getLocalPort}")
+ }
+ sock = serverSocketChannel.accept()
+ logTrace(s"Connection accepted from address ${sock.getRemoteAddress}")
authHelper.authClient(sock)
logTrace("Client authenticated")
promise.complete(Try(handleConnection(sock)))
} finally {
logTrace("Closing server")
- JavaUtils.closeQuietly(serverSocket)
+ JavaUtils.closeQuietly(serverSocketChannel)
JavaUtils.closeQuietly(sock)
+ if (isUnixDomainSock) sockPath.delete()
}
}
}.start()
- (serverSocket.getLocalPort, authHelper.secret)
+ if (isUnixDomainSock) {
+ (sockPath.getPath, null)
+ } else {
+ (serverSocketChannel.socket().getLocalPort, authHelper.secret)
+ }
}
- val (port, secret) = startServer()
+ // connInfo is either a string (for UDS) or a port number (for TCP/IP).
+ val (connInfo, secret) = startServer()
/**
* Handle a connection which has already been authenticated. Any error from this function
* will clean up this connection and the entire server, and get propagated to [[getResult]].
*/
- def handleConnection(sock: Socket): T
+ def handleConnection(sock: SocketChannel): T
/**
* Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If
@@ -108,9 +136,9 @@ private[spark] abstract class SocketAuthServer[T](
private[spark] class SocketFuncServer(
authHelper: SocketAuthHelper,
threadName: String,
- func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, threadName) {
+ func: SocketChannel => Unit) extends SocketAuthServer[Unit](authHelper, threadName) {
- override def handleConnection(sock: Socket): Unit = {
+ override def handleConnection(sock: SocketChannel): Unit = {
func(sock)
}
}
@@ -134,8 +162,8 @@ private[spark] object SocketAuthServer {
def serveToStream(
threadName: String,
authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
- val handleFunc = (sock: Socket) => {
- val out = new BufferedOutputStream(sock.getOutputStream())
+ val handleFunc = (sock: SocketChannel) => {
+ val out = new BufferedOutputStream(Channels.newOutputStream(sock))
Utils.tryWithSafeFinally {
writeFunc(out)
} {
@@ -144,6 +172,6 @@ private[spark] object SocketAuthServer {
}
val server = new SocketFuncServer(authHelper, threadName, handleFunc)
- Array(server.port, server.secret, server)
+ Array(server.connInfo, server.secret, server)
}
}
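
serveToStream now exposes connInfo instead of a bare port: a socket file path (String) under UDS, where the secret is null, and a loopback port (Int) otherwise, matching the pair returned by startServer(). A minimal client-side sketch of dialing such a value with java.nio (the UNIX protocol family needs Java 16+; the connect helper below is hypothetical):

import java.net.{InetAddress, InetSocketAddress, StandardProtocolFamily, UnixDomainSocketAddress}
import java.nio.channels.SocketChannel

def connect(connInfo: Any): SocketChannel = connInfo match {
  case path: String =>                    // UDS: connInfo is the socket file path
    val ch = SocketChannel.open(StandardProtocolFamily.UNIX)
    ch.connect(UnixDomainSocketAddress.of(path))
    ch
  case port: Int =>                       // TCP: connInfo is a loopback port
    SocketChannel.open(new InetSocketAddress(InetAddress.getLoopbackAddress, port))
  case other =>
    throw new IllegalArgumentException(s"Unexpected connection info: $other")
}
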
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 783da1fa4c286..000ba8d79bc02 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -47,6 +47,7 @@ import org.apache.spark.internal.io.FileCommitProtocol._
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{BoundedPriorityQueue, ByteBufferInputStream, NextIterator, SerializableConfiguration, SerializableJobConf, Utils}
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
import org.apache.spark.util.io.ChunkedByteBuffer
@@ -234,6 +235,7 @@ class KryoSerializer(conf: SparkConf)
kryo.register(Utils.classForName("scala.reflect.ClassTag$GenericClassTag"))
kryo.register(classOf[ArrayBuffer[Any]])
kryo.register(classOf[Array[Array[Byte]]])
+ kryo.register(classOf[UTF8String])
// We can't load those class directly in order to avoid unnecessary jar dependencies.
// We load them safely, ignore it if the class not found.
@@ -573,6 +575,7 @@ private[serializer] object KryoSerializer {
"org.apache.spark.sql.catalyst.expressions.BoundReference",
"org.apache.spark.sql.catalyst.expressions.SortOrder",
"[Lorg.apache.spark.sql.catalyst.expressions.SortOrder;",
+ "org.apache.spark.sql.catalyst.expressions.GenericInternalRow",
"org.apache.spark.sql.catalyst.InternalRow",
"org.apache.spark.sql.catalyst.InternalRow$",
"[Lorg.apache.spark.sql.catalyst.InternalRow;",
@@ -607,6 +610,12 @@ private[serializer] object KryoSerializer {
"org.apache.spark.sql.execution.joins.LongHashedRelation",
"org.apache.spark.sql.execution.joins.LongToUnsafeRowMap",
"org.apache.spark.sql.execution.joins.UnsafeHashedRelation",
+ "org.apache.spark.sql.columnar.CachedBatch",
+ "org.apache.spark.sql.columnar.SimpleMetricsCachedBatch",
+ "org.apache.spark.sql.execution.columnar.DefaultCachedBatch",
+ "org.apache.spark.sql.columnar.CachedBatchSerializer",
+ "org.apache.spark.sql.columnar.SimpleMetricsCachedBatchSerializer",
+ "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer",
"org.apache.spark.ml.attribute.Attribute",
"org.apache.spark.ml.attribute.AttributeGroup",
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
index b05babdce1699..da08635eca4c5 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -111,7 +111,12 @@ private[spark] object SerializationDebugger extends Logging {
visitExternalizable(e, elem :: stack)
case s: Object with java.io.Serializable =>
- val elem = s"object (class ${s.getClass.getName}, $s)"
+ val str = try {
+ s.toString
+ } catch {
+ case NonFatal(_) => "exception in toString"
+ }
+ val elem = s"object (class ${s.getClass.getName}, $str)"
visitSerializable(s, elem :: stack)
case _ =>
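
The try/catch added above keeps a throwing toString (for example a Catalyst expression that needs SQLConf at string-rendering time) from masking the original serialization failure. The same pattern in isolation:

import scala.util.control.NonFatal

def safeToString(o: AnyRef): String =
  try o.toString catch { case NonFatal(_) => "exception in toString" }
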
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index fc4e6e771aad7..858db498e83ad 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -418,7 +418,7 @@ class BlockManagerMasterEndpoint(
if (externalShuffleServiceRemoveShuffleEnabled) {
mapOutputTracker.shuffleStatuses.get(shuffleId).foreach { shuffleStatus =>
shuffleStatus.withMapStatuses { mapStatuses =>
- mapStatuses.foreach { mapStatus =>
+ mapStatuses.filter(_ != null).foreach { mapStatus =>
// Check if the executor has been deallocated
if (!blockManagerIdByExecutor.contains(mapStatus.location.executorId)) {
val blocksToDel =
@@ -862,31 +862,50 @@ class BlockManagerMasterEndpoint(
private def getLocationsAndStatus(
blockId: BlockId,
requesterHost: String): Option[BlockLocationsAndStatus] = {
- val locations = Option(blockLocations.get(blockId)).map(_.toSeq).getOrElse(Seq.empty)
- val status = locations.headOption.flatMap { bmId =>
- if (externalShuffleServiceRddFetchEnabled && bmId.port == externalShuffleServicePort) {
- blockStatusByShuffleService.get(bmId).flatMap(m => m.get(blockId))
+ val allLocations = Option(blockLocations.get(blockId)).map(_.toSeq).getOrElse(Seq.empty)
+ val blockStatusWithBlockManagerId: Option[(BlockStatus, BlockManagerId)] =
+ (if (externalShuffleServiceRddFetchEnabled && blockId.isRDD) {
+        // If fetching disk-persisted RDD blocks from the external shuffle service is enabled,
+        // first try to find the block via that service, preferring the instance running on the
+        // requester's host. This search also covers blocks stored on already-killed executors.
+ val hostLocalLocations = allLocations.find { bmId =>
+ bmId.host == requesterHost && bmId.port == externalShuffleServicePort
+ }
+ val location = hostLocalLocations
+ .orElse(allLocations.find(_.port == externalShuffleServicePort))
+ location
+ .flatMap(blockStatusByShuffleService.get(_).flatMap(_.get(blockId)))
+ .zip(location)
} else {
- blockManagerInfo.get(bmId).flatMap(_.getStatus(blockId))
+        // Try to find the block on an executor running on the same host that holds it on disk.
+ // Implementation detail: using flatMap on iterators makes the transformation lazy.
+ allLocations.filter(_.host == requesterHost).iterator
+ .flatMap { bmId =>
+ blockManagerInfo.get(bmId).flatMap { blockInfo =>
+ blockInfo.getStatus(blockId).map((_, bmId))
+ }
+ }
+ .find(_._1.storageLevel.useDisk)
+ })
+ .orElse {
+        // If the block cannot be found on the same host as a disk-stored block, extend the
+        // search to all active (not killed) executors and to all storage levels.
+ val location = allLocations.headOption
+ location.flatMap(blockManagerInfo.get(_)).flatMap(_.getStatus(blockId)).zip(location)
}
- }
-
- if (locations.nonEmpty && status.isDefined) {
- val localDirs = locations.find { loc =>
- // When the external shuffle service running on the same host is found among the block
- // locations then the block must be persisted on the disk. In this case the executorId
- // can be used to access this block even when the original executor is already stopped.
- loc.host == requesterHost &&
- (loc.port == externalShuffleServicePort ||
- blockManagerInfo
- .get(loc)
- .flatMap(_.getStatus(blockId).map(_.storageLevel.useDisk))
- .getOrElse(false))
- }.flatMap { bmId => Option(executorIdToLocalDirs.getIfPresent(bmId.executorId)) }
- Some(BlockLocationsAndStatus(locations, status.get, localDirs))
- } else {
- None
- }
+ logDebug(s"Identified block: $blockStatusWithBlockManagerId")
+ blockStatusWithBlockManagerId
+ .map { case (blockStatus: BlockStatus, bmId: BlockManagerId) =>
+ if (bmId.host == requesterHost && blockStatus.storageLevel.useDisk) {
+ BlockLocationsAndStatus(
+ allLocations,
+ blockStatus,
+ Option(executorIdToLocalDirs.getIfPresent(bmId.executorId)))
+ } else {
+ BlockLocationsAndStatus(allLocations, blockStatus, None)
+ }
+ }
+ .orElse(None)
}
private def getLocationsMultipleBlockIds(
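
For readers following the rewritten getLocationsAndStatus, a simplified sketch of the preference order when disk-persisted RDD fetching through the external shuffle service applies; the case class and field names are invented for illustration:

case class Loc(host: String, onDiskAtLiveExecutor: Boolean, isShuffleService: Boolean)

def pickLocation(all: Seq[Loc], requesterHost: String): Option[Loc] = {
  all.find(l => l.isShuffleService && l.host == requesterHost)                  // shuffle service on this host
    .orElse(all.find(_.isShuffleService))                                       // any shuffle service location
    .orElse(all.find(l => l.host == requesterHost && l.onDiskAtLiveExecutor))   // live local executor, on disk
    .orElse(all.headOption)                                                     // fall back to any location
}
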
diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
index fc7a4675429aa..e95eeddbdace3 100644
--- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
@@ -48,7 +48,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
// Schedule a refresh thread to run periodically
private val timer = ThreadUtils.newDaemonSingleThreadScheduledExecutor("refresh progress")
- timer.scheduleAtFixedRate(
+ private val timerFuture = timer.scheduleAtFixedRate(
() => refresh(), firstDelayMSec, updatePeriodMSec, TimeUnit.MILLISECONDS)
/**
@@ -121,5 +121,8 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
* Tear down the timer thread. The timer thread is a GC root, and it retains the entire
* SparkContext if it's not terminated.
*/
- def stop(): Unit = ThreadUtils.shutdown(timer)
+ def stop(): Unit = {
+ timerFuture.cancel(false)
+ ThreadUtils.shutdown(timer)
+ }
}
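
The stop() change above cancels the periodic refresh before shutting the executor down, so the scheduled task stops promptly instead of relying on shutdown alone. The same pattern in standalone form:

import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}

val timer = Executors.newSingleThreadScheduledExecutor()
val tick: ScheduledFuture[_] =
  timer.scheduleAtFixedRate(() => println("refresh"), 0, 200, TimeUnit.MILLISECONDS)

def stop(): Unit = {
  tick.cancel(false)                         // no further runs; let an in-flight run finish
  timer.shutdown()
  timer.awaitTermination(1, TimeUnit.SECONDS)
}
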
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index fff6ec4f5b170..d66b9d849d55a 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -222,6 +222,7 @@ private[spark] object UIUtils extends Logging {
+
}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index 948acb7112c8e..5999ce02bb43d 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -50,9 +50,7 @@ private[spark] object UIWorkloadGenerator {
val conf = new SparkConf().setMaster(args(0)).setAppName("Spark UI tester")
val schedulingMode = SchedulingMode.withName(args(1))
- if (schedulingMode == SchedulingMode.FAIR) {
- conf.set(SCHEDULER_MODE, "FAIR")
- }
+ conf.set(SCHEDULER_MODE, schedulingMode)
val nJobSet = args(2).toInt
val sc = new SparkContext(conf)
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
index f2718c6bf8d77..df504e4cc8ef2 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
@@ -169,6 +169,7 @@ private[ui] class EnvironmentPage(
{classpathEntriesTable}
+
UIUtils.headerSparkPage(request, "Environment", content, parent)
diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala
index 07ea9720b0b8d..b72fcb53bdb62 100644
--- a/core/src/main/scala/org/apache/spark/util/Distribution.scala
+++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala
@@ -74,7 +74,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va
private[spark] object Distribution {
def apply(data: Iterable[Double]): Option[Distribution] = {
- if (data.size > 0) {
+ if (data.nonEmpty) {
Some(new Distribution(data))
} else {
None
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index e30380f41566a..df809f4fad745 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -376,9 +376,8 @@ private[spark] object JsonProtocol extends JsonUtils {
g.writeNumberField("Task ID", taskId)
g.writeNumberField("Stage ID", stageId)
g.writeNumberField("Stage Attempt ID", stageAttemptId)
- g.writeArrayFieldStart("Accumulator Updates")
- updates.foreach(accumulableInfoToJson(_, g))
- g.writeEndArray()
+ g.writeFieldName("Accumulator Updates")
+ accumulablesToJson(updates, g)
g.writeEndObject()
}
g.writeEndArray()
@@ -496,7 +495,7 @@ private[spark] object JsonProtocol extends JsonUtils {
def accumulablesToJson(
accumulables: Iterable[AccumulableInfo],
g: JsonGenerator,
- includeTaskMetricsAccumulators: Boolean = true): Unit = {
+ includeTaskMetricsAccumulators: Boolean = true): Unit = {
g.writeStartArray()
accumulables
.filterNot { acc =>
@@ -714,11 +713,8 @@ private[spark] object JsonProtocol extends JsonUtils {
reason.foreach(g.writeStringField("Loss Reason", _))
case taskKilled: TaskKilled =>
g.writeStringField("Kill Reason", taskKilled.reason)
- g.writeArrayFieldStart("Accumulator Updates")
- taskKilled.accumUpdates.foreach { info =>
- accumulableInfoToJson(info, g)
- }
- g.writeEndArray()
+ g.writeFieldName("Accumulator Updates")
+ accumulablesToJson(taskKilled.accumUpdates, g)
case _ =>
// no extra fields to write
}
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 7f61b3f0b2c24..e9d14f904db45 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -65,7 +65,7 @@ private[spark] object ThreadUtils {
}
}
- override def isTerminated: Boolean = synchronized {
+ override def isTerminated: Boolean = {
lock.lock()
try {
serviceIsShutdown && runningTasks == 0
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e5e4bcacc70c3..ea9b742fb2e1b 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -54,6 +54,7 @@ import org.apache.commons.io.IOUtils
import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
+import org.apache.hadoop.fs.audit.CommonAuditContext.currentAuditContext
import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec}
import org.apache.hadoop.ipc.{CallerContext => HadoopCallerContext}
import org.apache.hadoop.ipc.CallerContext.{Builder => HadoopCallerContextBuilder}
@@ -3084,7 +3085,7 @@ private[spark] object Utils
entry = in.getNextEntry()
}
in.close() // so that any error in closing does not get ignored
- logInfo(log"Unzipped from ${MDC(PATH, dfsZipFile)}\n\t${MDC(PATHS, files.mkString("\n\t"))}")
+ logDebug(log"Unzipped from ${MDC(PATH, dfsZipFile)}\n\t${MDC(PATHS, files.mkString("\n\t"))}")
} finally {
// Close everything no matter what happened
IOUtils.closeQuietly(in)
@@ -3171,6 +3172,9 @@ private[util] object CallerContext extends Logging {
* specific applications impacting parts of the Hadoop system and potential problems they may be
* creating (e.g. overloading NN). As HDFS mentioned in HDFS-9184, for a given HDFS operation, it's
* very helpful to track which upper level job issues it.
+ * The context information is also set in the audit context for cloud storage
+ * connectors. If supported, this gets marshalled as part of the HTTP Referrer header
+ * or similar field, and so ends up in the store service logs themselves.
*
* @param from who sets up the caller context (TASK, CLIENT, APPMASTER)
*
@@ -3221,11 +3225,15 @@ private[spark] class CallerContext(
/**
* Set up the caller context [[context]] by invoking Hadoop CallerContext API of
- * [[HadoopCallerContext]].
+ * [[HadoopCallerContext]], which is included in IPC calls,
+ * and the Hadoop audit context, which may be included in cloud storage
+ * requests.
*/
def setCurrentContext(): Unit = if (CallerContext.callerContextEnabled) {
val hdfsContext = new HadoopCallerContextBuilder(context).build()
HadoopCallerContext.setCurrent(hdfsContext)
+    // set the audit context for object stores, under the key "spark"
+ currentAuditContext.put("spark", context)
}
}
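
CommonAuditContext comes from Hadoop 3.3+; both put and reset appear elsewhere in this patch, and the sketch below only strings them together with a made-up context value:

import org.apache.hadoop.fs.audit.CommonAuditContext.currentAuditContext

currentAuditContext.reset()                 // start from a clean per-thread audit context
val context = "SPARK_TASK_app-123_JId_7"    // hypothetical caller-context string
currentAuditContext.put("spark", context)   // attached to object-store requests when the connector supports it
// ... issue object-store operations here; the value may surface in server-side access logs ...
currentAuditContext.reset()
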
diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
index f0a63247e64b1..e9a0c405b0d9f 100644
--- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.File
+import java.nio.ByteBuffer
import java.nio.file.Files
import java.nio.file.attribute.PosixFilePermission
@@ -35,8 +36,9 @@ import org.apache.spark.network.TransportContext
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.server.TransportServer
import org.apache.spark.network.shuffle.{ExecutorDiskUtils, ExternalBlockHandler, ExternalBlockStoreClient}
-import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId, StorageLevel}
+import org.apache.spark.storage.{BroadcastBlockId, RDDBlockId, ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId, StorageLevel}
import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.io.ChunkedByteBuffer
/**
* This suite creates an external shuffle server and routes all shuffle fetches through it.
@@ -117,12 +119,15 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi
.set(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED, true)
.set(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, true)
.set(config.EXECUTOR_REMOVE_DELAY.key, "0s")
+ .set(config.DRIVER_BIND_ADDRESS.key, Utils.localHostName())
sc = new SparkContext("local-cluster[1,1,1024]", "test", confWithRddFetchEnabled)
sc.env.blockManager.externalShuffleServiceEnabled should equal(true)
sc.env.blockManager.blockStoreClient.getClass should equal(classOf[ExternalBlockStoreClient])
try {
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
val rdd = sc.parallelize(0 until 100, 2)
- .map { i => (i, 1) }
+ .map { i => (i, broadcast.value.size) }
.persist(StorageLevel.DISK_ONLY)
rdd.count()
@@ -173,8 +178,59 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi
"external shuffle service port should be contained")
}
+ eventually(timeout(2.seconds), interval(100.milliseconds)) {
+ val locationStatusForLocalHost =
+ sc.env.blockManager.master.getLocationsAndStatus(blockId, Utils.localHostName())
+ assert(locationStatusForLocalHost.isDefined)
+ assert(locationStatusForLocalHost.get.localDirs.isDefined)
+ assert(locationStatusForLocalHost.get.locations.head.executorId == "0")
+ assert(locationStatusForLocalHost.get.locations.head.host == Utils.localHostName())
+ }
+
+ eventually(timeout(2.seconds), interval(100.milliseconds)) {
+ val locationStatusForRemoteHost =
+ sc.env.blockManager.master.getLocationsAndStatus(blockId, "")
+ assert(locationStatusForRemoteHost.isDefined)
+ assert(locationStatusForRemoteHost.get.localDirs.isEmpty)
+ assert(locationStatusForRemoteHost.get.locations.head.executorId == "0")
+ assert(locationStatusForRemoteHost.get.locations.head.host == Utils.localHostName())
+ }
+
assert(sc.env.blockManager.getRemoteValues(blockId).isDefined)
+ eventually(timeout(2.seconds), interval(100.milliseconds)) {
+ val broadcastBlockId = BroadcastBlockId(broadcast.id, "piece0")
+ val locStatusForMemBroadcast =
+ sc.env.blockManager.master.getLocationsAndStatus(broadcastBlockId, Utils.localHostName())
+ assert(locStatusForMemBroadcast.isDefined)
+ assert(locStatusForMemBroadcast.get.localDirs.isEmpty)
+ assert(locStatusForMemBroadcast.get.locations.head.executorId == "driver")
+ assert(locStatusForMemBroadcast.get.locations.head.host == Utils.localHostName())
+ }
+
+ val byteBuffer = ByteBuffer.wrap(Array[Byte](7))
+ val bytes = new ChunkedByteBuffer(Array(byteBuffer))
+ val diskBroadcastId = BroadcastBlockId(Long.MaxValue, "piece0")
+ sc.env.blockManager.putBytes(diskBroadcastId, bytes, StorageLevel.DISK_ONLY,
+ tellMaster = true)
+ eventually(timeout(2.seconds), interval(100.milliseconds)) {
+ val locStatusForDiskBroadcast =
+ sc.env.blockManager.master.getLocationsAndStatus(diskBroadcastId, Utils.localHostName())
+ assert(locStatusForDiskBroadcast.isDefined)
+ assert(locStatusForDiskBroadcast.get.localDirs.isDefined)
+ assert(locStatusForDiskBroadcast.get.locations.head.executorId == "driver")
+ assert(locStatusForDiskBroadcast.get.locations.head.host == Utils.localHostName())
+ }
+
+ eventually(timeout(2.seconds), interval(100.milliseconds)) {
+ val locStatusForDiskBroadcastForFetch =
+ sc.env.blockManager.master.getLocationsAndStatus(diskBroadcastId, "")
+ assert(locStatusForDiskBroadcastForFetch.isDefined)
+ assert(locStatusForDiskBroadcastForFetch.get.localDirs.isEmpty)
+ assert(locStatusForDiskBroadcastForFetch.get.locations.head.executorId == "driver")
+ assert(locStatusForDiskBroadcastForFetch.get.locations.head.host == Utils.localHostName())
+ }
+
// test unpersist: as executors are killed the blocks will be removed via the shuffle service
rdd.unpersist(true)
assert(sc.env.blockManager.getRemoteValues(blockId).isEmpty)
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 65ed2684a5b00..da6b57a0bccb9 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -59,7 +59,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
}
test("local mode, FIFO scheduler") {
- val conf = new SparkConf().set(SCHEDULER_MODE, "FIFO")
+ val conf = new SparkConf().set(SCHEDULER_MODE.key, "FIFO")
sc = new SparkContext("local[2]", "test", conf)
testCount()
testTake()
@@ -68,7 +68,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
}
test("local mode, fair scheduler") {
- val conf = new SparkConf().set(SCHEDULER_MODE, "FAIR")
+ val conf = new SparkConf().set(SCHEDULER_MODE.key, "FAIR")
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
conf.set(SCHEDULER_ALLOCATION_FILE, xmlPath)
sc = new SparkContext("local[2]", "test", conf)
@@ -79,7 +79,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
}
test("cluster mode, FIFO scheduler") {
- val conf = new SparkConf().set(SCHEDULER_MODE, "FIFO")
+ val conf = new SparkConf().set(SCHEDULER_MODE.key, "FIFO")
sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
testCount()
testTake()
@@ -88,7 +88,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
}
test("cluster mode, fair scheduler") {
- val conf = new SparkConf().set(SCHEDULER_MODE, "FAIR")
+ val conf = new SparkConf().set(SCHEDULER_MODE.key, "FAIR")
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
conf.set(SCHEDULER_ALLOCATION_FILE, xmlPath)
sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 3801da82b5df4..6473f823406c4 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -244,10 +244,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
}
test("add and list jar files") {
- val jarPath = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar")
+ val testJar = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar")
+ assume(testJar != null)
try {
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
- sc.addJar(jarPath.toString)
+ sc.addJar(testJar.toString)
assert(sc.listJars().count(_.contains("TestUDTF.jar")) == 1)
} finally {
sc.stop()
@@ -396,13 +397,15 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
schedulingMode <- Seq("local-mode", "non-local-mode");
method <- Seq("addJar", "addFile")
) {
- val jarPath = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar").toString
val master = schedulingMode match {
case "local-mode" => "local"
case "non-local-mode" => "local-cluster[1,1,1024]"
}
test(s"$method can be called twice with same file in $schedulingMode (SPARK-16787)") {
+ val testJar = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar")
+ assume(testJar != null)
sc = new SparkContext(master, "test")
+ val jarPath = testJar.toString
method match {
case "addJar" =>
sc.addJar(jarPath)
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index e38efc27b78f9..ca8326918feca 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -371,10 +371,17 @@ abstract class SparkFunSuite
"Invalid objectType of a query context Actual:" + actual.toString)
assert(actual.objectName() === expected.objectName,
"Invalid objectName of a query context. Actual:" + actual.toString)
- assert(actual.startIndex() === expected.startIndex,
- "Invalid startIndex of a query context. Actual:" + actual.toString)
- assert(actual.stopIndex() === expected.stopIndex,
- "Invalid stopIndex of a query context. Actual:" + actual.toString)
+      // If startIndex and stopIndex are -1, we only want to check the fragment of the
+      // query context. This should be the case when the fragment is unambiguous
+      // (appears exactly once) within the query text.
+ if (expected.startIndex != -1) {
+ assert(actual.startIndex() === expected.startIndex,
+ "Invalid startIndex of a query context. Actual:" + actual.toString)
+ }
+ if (expected.stopIndex != -1) {
+ assert(actual.stopIndex() === expected.stopIndex,
+ "Invalid stopIndex of a query context. Actual:" + actual.toString)
+ }
assert(actual.fragment() === expected.fragment,
"Invalid fragment of a query context. Actual:" + actual.toString)
} else if (actual.contextType() == QueryContextType.DataFrame) {
@@ -478,6 +485,12 @@ abstract class SparkFunSuite
ExpectedContext("", "", start, stop, fragment)
}
+    // Check the fragment only. This should only be used when the fragment is unambiguous
+    // within the query text.
+ def apply(fragment: String): ExpectedContext = {
+ ExpectedContext("", "", -1, -1, fragment)
+ }
+
def apply(
objectType: String,
objectName: String,
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 88ad5b3a7483f..4efd2870cccb0 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -18,7 +18,8 @@
package org.apache.spark.api.python
import java.io.{ByteArrayOutputStream, DataOutputStream, File}
-import java.net.{InetAddress, Socket}
+import java.net.{InetAddress, InetSocketAddress}
+import java.nio.channels.SocketChannel
import java.nio.charset.StandardCharsets
import java.util
@@ -33,6 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.rdd.{HadoopRDD, RDD}
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
import org.apache.spark.util.Utils
@@ -76,10 +78,14 @@ class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
}
test("python server error handling") {
- val authHelper = new SocketAuthHelper(new SparkConf())
+ val conf = new SparkConf()
+ conf.set(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key, false.toString)
+ val authHelper = new SocketAuthHelper(conf)
val errorServer = new ExceptionPythonServer(authHelper)
- val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
- authHelper.authToServer(client)
+ val socketChannel = SocketChannel.open(
+ new InetSocketAddress(InetAddress.getLoopbackAddress(),
+ errorServer.connInfo.asInstanceOf[Int]))
+ authHelper.authToServer(socketChannel)
val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
assert(ex.getCause().getMessage().contains("exception within handleConnection"))
}
@@ -87,7 +93,7 @@ class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
class ExceptionPythonServer(authHelper: SocketAuthHelper)
extends SocketAuthServer[Unit](authHelper, "error-server") {
- override def handleConnection(sock: Socket): Unit = {
+ override def handleConnection(sock: SocketChannel): Unit = {
throw new Exception("exception within handleConnection")
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index ca81283e073ac..bd34e6f2bba3d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -506,6 +506,8 @@ class SparkSubmitSuite
test("SPARK-47495: Not to add primary resource to jars again" +
" in k8s client mode & driver runs inside a POD") {
+ val testJar = "src/test/resources/TestUDTF.jar"
+ assume(new File(testJar).exists)
val clArgs = Seq(
"--deploy-mode", "client",
"--proxy-user", "test.user",
@@ -514,7 +516,7 @@ class SparkSubmitSuite
"--class", "org.SomeClass",
"--driver-memory", "1g",
"--conf", "spark.kubernetes.submitInDriver=true",
- "--jars", "src/test/resources/TestUDTF.jar",
+ "--jars", testJar,
"/home/jarToIgnore.jar",
"arg1")
val appArgs = new SparkSubmitArguments(clArgs)
@@ -524,6 +526,8 @@ class SparkSubmitSuite
}
test("SPARK-33782: handles k8s files download to current directory") {
+ val testJar = "src/test/resources/TestUDTF.jar"
+ assume(new File(testJar).exists)
val clArgs = Seq(
"--deploy-mode", "client",
"--proxy-user", "test.user",
@@ -537,7 +541,7 @@ class SparkSubmitSuite
"--files", "src/test/resources/test_metrics_config.properties",
"--py-files", "src/test/resources/test_metrics_system.properties",
"--archives", "src/test/resources/log4j2.properties",
- "--jars", "src/test/resources/TestUDTF.jar",
+ "--jars", testJar,
"/home/thejar.jar",
"arg1")
val appArgs = new SparkSubmitArguments(clArgs)
@@ -561,6 +565,8 @@ class SparkSubmitSuite
test("SPARK-47475: Avoid jars download if scheme matches " +
"spark.kubernetes.jars.avoidDownloadSchemes " +
"in k8s client mode & driver runs inside a POD") {
+ val testJar = "src/test/resources/TestUDTF.jar"
+ assume(new File(testJar).exists)
val hadoopConf = new Configuration()
updateConfWithFakeS3Fs(hadoopConf)
withTempDir { tmpDir =>
@@ -579,7 +585,7 @@ class SparkSubmitSuite
"--files", "src/test/resources/test_metrics_config.properties",
"--py-files", "src/test/resources/test_metrics_system.properties",
"--archives", "src/test/resources/log4j2.properties",
- "--jars", s"src/test/resources/TestUDTF.jar,$remoteJarFile",
+ "--jars", s"$testJar,$remoteJarFile",
"/home/jarToIgnore.jar",
"arg1")
val appArgs = new SparkSubmitArguments(clArgs)
diff --git a/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala
index 637b459886bcf..6884196655081 100644
--- a/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala
@@ -17,10 +17,14 @@
package org.apache.spark.executor
+import java.io.File
+import java.net.URL
+
import scala.util.Properties
-import org.apache.spark.{JobArtifactSet, JobArtifactState, LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
-import org.apache.spark.util.Utils
+import org.apache.spark.{JobArtifactSet, JobArtifactState, LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TestUtils}
+import org.apache.spark.util.{MutableURLClassLoader, Utils}
+
class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext {
@@ -29,21 +33,28 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext {
.take(2)
.mkString(".")
- val jar1 = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar").toString
+ private val jarURL1 = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar")
+ private lazy val jar1 = jarURL1.toString
// package com.example
// object Hello { def test(): Int = 2 }
// case class Hello(x: Int, y: Int)
- val jar2 = Thread.currentThread().getContextClassLoader
- .getResource(s"TestHelloV2_$scalaVersion.jar").toString
+ private val jarURL2 = Thread.currentThread().getContextClassLoader
+ .getResource(s"TestHelloV2_$scalaVersion.jar")
+ private lazy val jar2 = jarURL2.toString
// package com.example
// object Hello { def test(): Int = 3 }
// case class Hello(x: String)
- val jar3 = Thread.currentThread().getContextClassLoader
- .getResource(s"TestHelloV3_$scalaVersion.jar").toString
+ private val jarURL3 = Thread.currentThread().getContextClassLoader
+ .getResource(s"TestHelloV3_$scalaVersion.jar")
+ private lazy val jar3 = jarURL3.toString
test("Executor classloader isolation with JobArtifactSet") {
+ assume(jarURL1 != null)
+ assume(jarURL2 != null)
+ assume(jarURL3 != null)
+
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
sc.addJar(jar1)
sc.addJar(jar2)
@@ -109,4 +120,134 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext {
}
}
}
+
+ test("SPARK-51537 Executor isolation session classloader inherits from " +
+ "default session classloader") {
+ assume(jarURL2 != null)
+ sc = new SparkContext(new SparkConf()
+ .setAppName("test")
+ .setMaster("local")
+ .set("spark.jars", jar2))
+
+ // TestHelloV2's test method returns '2'
+ val artifactSetWithHelloV2 = new JobArtifactSet(
+ Some(JobArtifactState(uuid = "hello2", replClassDirUri = None)),
+ jars = Map.empty,
+ files = Map.empty,
+ archives = Map.empty
+ )
+
+ JobArtifactSet.withActiveJobArtifactState(artifactSetWithHelloV2.state.get) {
+ sc.parallelize(1 to 1).foreach { _ =>
+ val cls = Utils.classForName("com.example.Hello$")
+ val module = cls.getField("MODULE$").get(null)
+ val result = cls.getMethod("test").invoke(module).asInstanceOf[Int]
+ if (result != 2) {
+ throw new RuntimeException("Unexpected result: " + result)
+ }
+ }
+ }
+ }
+
+ test("SPARK-51537 Executor isolation avoids reloading plugin jars") {
+ val tempDir = Utils.createTempDir()
+
+ val testCodeBody =
+ s"""
+ | public static boolean flag = false;
+ |""".stripMargin
+
+ val compiledTestCode = TestUtils.createCompiledClass(
+ "TestFoo",
+ tempDir,
+ "",
+ null,
+ Seq.empty,
+ Seq.empty,
+ testCodeBody)
+
+    // Initialize the static variable TestFoo.flag when the plugin is loaded for the first
+    // time. If the plugin is reloaded, TestFoo.flag falls back to its default value, false.
+ val executorPluginCodeBody =
+ s"""
+ |@Override
+ |public void init(
+ | org.apache.spark.api.plugin.PluginContext ctx,
+ | java.util.Map extraConf) {
+ | TestFoo.flag = true;
+ |}
+ """.stripMargin
+
+ val thisClassPath =
+ sys.props("java.class.path").split(File.pathSeparator).map(p => new File(p).toURI.toURL)
+
+ val compiledExecutorPlugin = TestUtils.createCompiledClass(
+ "TestExecutorPlugin",
+ tempDir,
+ "",
+ null,
+ Seq(tempDir.toURI.toURL) ++ thisClassPath,
+ Seq("org.apache.spark.api.plugin.ExecutorPlugin"),
+ executorPluginCodeBody)
+
+ val sparkPluginCodeBody =
+ """
+ |@Override
+ |public org.apache.spark.api.plugin.ExecutorPlugin executorPlugin() {
+ | return new TestExecutorPlugin();
+ |}
+ |
+ |@Override
+ |public org.apache.spark.api.plugin.DriverPlugin driverPlugin() { return null; }
+ """.stripMargin
+
+ val compiledSparkPlugin = TestUtils.createCompiledClass(
+ "TestSparkPlugin",
+ tempDir,
+ "",
+ null,
+ Seq(tempDir.toURI.toURL) ++ thisClassPath,
+ Seq("org.apache.spark.api.plugin.SparkPlugin"),
+ sparkPluginCodeBody)
+
+ val jarUrl = TestUtils.createJar(
+ Seq(compiledSparkPlugin, compiledExecutorPlugin, compiledTestCode),
+ new File(tempDir, "testplugin.jar"))
+
+ def getClassLoader: MutableURLClassLoader = {
+ val loader = new MutableURLClassLoader(new Array[URL](0),
+ Thread.currentThread.getContextClassLoader)
+ Thread.currentThread.setContextClassLoader(loader)
+ loader
+ }
+    // SparkContext does not add plugin jars specified by the `spark.jars` configuration
+    // to the classpath, which causes a ClassNotFoundException when plugins are initialized
+    // in SparkContext. We manually add the jar to the class loader to work around this.
+ val loader = getClassLoader
+ loader.addURL(jarUrl)
+
+ sc = new SparkContext(new SparkConf()
+ .setAppName("avoid-reloading-plugins")
+ .setMaster("local-cluster[1, 1, 1024]")
+ .set("spark.jars", jarUrl.toString)
+ .set("spark.plugins", "TestSparkPlugin"))
+
+ val jobArtifactSet = new JobArtifactSet(
+ Some(JobArtifactState(uuid = "avoid-reloading-plugins", replClassDirUri = None)),
+ jars = Map.empty,
+ files = Map.empty,
+ archives = Map.empty
+ )
+
+ JobArtifactSet.withActiveJobArtifactState(jobArtifactSet.state.get) {
+ sc.parallelize(1 to 1).foreach { _ =>
+ val cls1 = Utils.classForName("TestFoo")
+ val z = cls1.getField("flag").getBoolean(null)
+        // If the plugin has been reloaded, TestFoo.flag will be false.
+ if (!z) {
+ throw new RuntimeException("The spark plugin is reloaded")
+ }
+ }
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index fa13092dc47aa..6f525cf8b898a 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -55,7 +55,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task,
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockManager, BlockManagerId}
-import org.apache.spark.util.{LongAccumulator, SparkUncaughtExceptionHandler, ThreadUtils, UninterruptibleThread}
+import org.apache.spark.util.{LongAccumulator, SparkUncaughtExceptionHandler, ThreadUtils, UninterruptibleThread, Utils}
class ExecutorSuite extends SparkFunSuite
with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {
@@ -81,6 +81,8 @@ class ExecutorSuite extends SparkFunSuite
resources: immutable.Map[String, ResourceInformation]
= immutable.Map.empty[String, ResourceInformation])(f: Executor => Unit): Unit = {
var executor: Executor = null
+ val getCustomHostname = PrivateMethod[Option[String]](Symbol("customHostname"))
+ val defaultCustomHostNameValue = Utils.invokePrivate(getCustomHostname())
try {
executor = new Executor(executorId, executorHostname, env, userClassPath, isLocal,
uncaughtExceptionHandler, resources)
@@ -90,6 +92,10 @@ class ExecutorSuite extends SparkFunSuite
if (executor != null) {
executor.stop()
}
+      // SPARK-51633: Reset the custom hostname to its default value in the finally block
+      // to avoid contaminating other tests.
+ val setCustomHostname = PrivateMethod[Unit](Symbol("customHostname_$eq"))
+ Utils.invokePrivate(setCustomHostname(defaultCustomHostNameValue))
}
}
diff --git a/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala b/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala
index 573540180e6cb..77d782461a2ec 100644
--- a/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala
@@ -86,7 +86,6 @@ class ProcfsMetricsGetterSuite extends SparkFunSuite {
val child = process.toHandle.pid()
eventually(timeout(10.seconds), interval(100.milliseconds)) {
val pids = p.computeProcessTree()
- assert(pids.size === 3)
assert(pids.contains(currentPid))
assert(pids.contains(child))
}
diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
index ae99735084056..457e92b062808 100644
--- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
+++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
@@ -21,6 +21,7 @@ import java.util.Locale
import java.util.concurrent.TimeUnit
import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.network.shuffledb.DBBackend
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.util.SparkConfWithEnv
@@ -387,4 +388,41 @@ class ConfigEntrySuite extends SparkFunSuite {
ConfigBuilder(testKey("oc5")).onCreate(_ => onCreateCalled = true).fallbackConf(fallback)
assert(onCreateCalled)
}
+
+
+ test("SPARK-51874: Add Scala Enumeration support to ConfigBuilder") {
+ object MyTestEnum extends Enumeration {
+ val X, Y, Z = Value
+ }
+ val conf = new SparkConf()
+ val enumConf = ConfigBuilder("spark.test.enum.key")
+ .enumConf(MyTestEnum)
+ .createWithDefault(MyTestEnum.X)
+ assert(conf.get(enumConf) === MyTestEnum.X)
+ conf.set(enumConf, MyTestEnum.Y)
+ assert(conf.get(enumConf) === MyTestEnum.Y)
+ conf.set(enumConf.key, "Z")
+ assert(conf.get(enumConf) === MyTestEnum.Z)
+ val e = intercept[IllegalArgumentException] {
+ conf.set(enumConf.key, "A")
+ conf.get(enumConf)
+ }
+ assert(e.getMessage === s"${enumConf.key} should be one of X, Y, Z, but was A")
+ }
+
+ test("SPARK-51896: Add Java enum support to ConfigBuilder") {
+ val conf = new SparkConf()
+ val enumConf = ConfigBuilder("spark.test.java.enum.key")
+ .enumConf(classOf[DBBackend])
+ .createWithDefault(DBBackend.LEVELDB)
+ assert(conf.get(enumConf) === DBBackend.LEVELDB)
+ conf.set(enumConf, DBBackend.ROCKSDB)
+ assert(conf.get(enumConf) === DBBackend.ROCKSDB)
+ val e = intercept[IllegalArgumentException] {
+ conf.set(enumConf.key, "ANYDB")
+ conf.get(enumConf)
+ }
+ assert(e.getMessage ===
+ s"${enumConf.key} should be one of ${DBBackend.values.mkString(", ")}, but was ANYDB")
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 7607d4d9fe6d9..be1bc5fe3212a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -1573,6 +1573,8 @@ class TaskSetManagerSuite
}
test("SPARK-21563 context's added jars shouldn't change mid-TaskSet") {
+ val jarPath = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar")
+ assume(jarPath != null)
sc = new SparkContext("local", "test")
val addedJarsPreTaskSet = Map[String, Long](sc.allAddedJars.toSeq: _*)
assert(addedJarsPreTaskSet.size === 0)
@@ -1588,7 +1590,6 @@ class TaskSetManagerSuite
assert(taskOption2.get.artifacts.jars === addedJarsPreTaskSet)
// even with a jar added mid-TaskSet
- val jarPath = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar")
sc.addJar(jarPath.toString)
val addedJarsMidTaskSet = Map[String, Long](sc.allAddedJars.toSeq: _*)
assert(addedJarsPreTaskSet !== addedJarsMidTaskSet)
diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
index e57cb701b6284..c5a6199cf4c1d 100644
--- a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
@@ -18,14 +18,17 @@ package org.apache.spark.security
import java.io.Closeable
import java.net._
+import java.nio.channels.SocketChannel
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._
+import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.util.Utils
class SocketAuthHelperSuite extends SparkFunSuite {
private val conf = new SparkConf()
+ conf.set(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key, false.toString)
private val authHelper = new SocketAuthHelper(conf)
test("successful auth") {
@@ -43,7 +46,9 @@ class SocketAuthHelperSuite extends SparkFunSuite {
test("failed auth") {
Utils.tryWithResource(new ServerThread()) { server =>
Utils.tryWithResource(server.createClient()) { client =>
- val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128))
+ val badHelper = new SocketAuthHelper(new SparkConf()
+ .set(AUTH_SECRET_BIT_LENGTH, 128)
+ .set(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key, false.toString))
intercept[IllegalArgumentException] {
badHelper.authToServer(client)
}
@@ -66,8 +71,9 @@ class SocketAuthHelperSuite extends SparkFunSuite {
setDaemon(true)
start()
- def createClient(): Socket = {
- new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort())
+ def createClient(): SocketChannel = {
+ SocketChannel.open(new InetSocketAddress(
+ InetAddress.getLoopbackAddress(), ss.getLocalPort))
}
override def run(): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala
index e903cf31d69f2..e5f46012288ba 100644
--- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala
@@ -21,7 +21,7 @@ import java.io._
import scala.annotation.meta.param
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
class SerializationDebuggerSuite extends SparkFunSuite {
@@ -180,6 +180,15 @@ class SerializationDebuggerSuite extends SparkFunSuite {
assert(e.getMessage.contains("SerializableClass2")) // found debug trace should be present
}
+ test("SPARK-51691 improveException swallow underlying exception") {
+ val e = SerializationDebugger.improveException(
+ new SerializableClassWithStringException(new NotSerializable),
+ new NotSerializableException("someClass"))
+ assert(e.getMessage.contains("exception in toString"))
+ assert(e.getMessage.contains("someClass"))
+ assert(e.getMessage.contains("SerializableClassWithStringException"))
+ }
+
test("improveException with error in debugger") {
// Object that throws exception in the SerializationDebugger
val o = new SerializableClass1 {
@@ -205,6 +214,14 @@ class SerializableClass1 extends Serializable
class SerializableClass2(val objectField: Object) extends Serializable
+class SerializableClassWithStringException(val objectField: Object) extends Serializable {
+ override def toString: String = {
+    // Simulate the behavior of TreeNode#toString, where SQLConf.get may throw an exception.
+ throw new SparkRuntimeException(errorClass = "INTERNAL_ERROR",
+ messageParameters = Map("message" -> "this is an internal error"),
+ cause = null)
+ }
+}
class SerializableArray(val arrayField: Array[Object]) extends Serializable
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index ed2a1e7fadfa7..b373e295d5734 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -474,6 +474,26 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
assert(!BlockManagerId("notADriverIdentifier", "XXX", 1).isDriver)
}
+ test("SPARK-43221: Host local block fetching should use a block status with disk size") {
+ conf.set(IO_ENCRYPTION_ENABLED, true)
+ conf.set(SHUFFLE_SERVICE_FETCH_RDD_ENABLED, true)
+ val store1 = makeBlockManager(2000, "exec1")
+ val store2 = makeBlockManager(2000, "exec2")
+ val store3 = makeBlockManager(2000, "exec3")
+ val store4 = makeBlockManager(2000, "exec4")
+ val value = new Array[Byte](100)
+ val broadcastId = BroadcastBlockId(0)
+ store1.putSingle(broadcastId, value, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ store2.putSingle(broadcastId, value, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ store3.putSingle(broadcastId, value, StorageLevel.DISK_ONLY, tellMaster = true)
+ store4.getRemoteBytes(broadcastId) match {
+ case Some(block) =>
+ assert(block.size > 0, "The block size must be greater than 0 for a nonempty block!")
+ case None =>
+ assert(false, "Block not found!")
+ }
+ }
+
test("master + 1 manager interaction") {
val store = makeBlockManager(20000)
val a1 = new Array[Byte](4000)
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 89e3d8371be4c..a9399edeb9ad7 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -1166,7 +1166,9 @@ private[spark] object JsonProtocolSuite extends Assertions {
assert(taskId1 === taskId2)
assert(stageId1 === stageId2)
assert(stageAttemptId1 === stageAttemptId2)
- assertSeqEquals[AccumulableInfo](updates1, updates2, (a, b) => a.equals(b))
+ val filteredUpdates = updates1
+ .filterNot { acc => acc.name.exists(accumulableExcludeList.contains) }
+ assertSeqEquals[AccumulableInfo](filteredUpdates, updates2, (a, b) => a.equals(b))
})
assertSeqEquals[((Int, Int), ExecutorMetrics)](
e1.executorUpdates.toSeq.sortBy(_._1),
@@ -1299,7 +1301,9 @@ private[spark] object JsonProtocolSuite extends Assertions {
assert(r1.description === r2.description)
assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals)
assert(r1.fullStackTrace === r2.fullStackTrace)
- assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b))
+ val filteredUpdates = r1.accumUpdates
+ .filterNot { acc => acc.name.exists(accumulableExcludeList.contains) }
+ assertSeqEquals[AccumulableInfo](filteredUpdates, r2.accumUpdates, (a, b) => a.equals(b))
case (TaskResultLost, TaskResultLost) =>
case (r1: TaskKilled, r2: TaskKilled) =>
assert(r1.reason == r2.reason)
@@ -2774,28 +2778,6 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Count Failed Values": true
| },
| {
- | "ID": 12,
- | "Name": "$UPDATED_BLOCK_STATUSES",
- | "Update": [
- | {
- | "Block ID": "rdd_0_0",
- | "Status": {
- | "Storage Level": {
- | "Use Disk": true,
- | "Use Memory": true,
- | "Use Off Heap": false,
- | "Deserialized": false,
- | "Replication": 2
- | },
- | "Memory Size": 0,
- | "Disk Size": 0
- | }
- | }
- | ],
- | "Internal": true,
- | "Count Failed Values": true
- | },
- | {
| "ID": 13,
| "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}",
| "Update": 0,
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 3312bd3d5743f..077dd489378fd 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -37,6 +37,7 @@ import org.apache.commons.lang3.SystemUtils
import org.apache.commons.math3.stat.inference.ChiSquareTest
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.audit.CommonAuditContext.currentAuditContext
import org.apache.hadoop.ipc.{CallerContext => HadoopCallerContext}
import org.apache.logging.log4j.Level
@@ -1003,9 +1004,19 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties {
}
test("Set Spark CallerContext") {
- val context = "test"
- new CallerContext(context).setCurrentContext()
- assert(s"SPARK_$context" === HadoopCallerContext.getCurrent.toString)
+ currentAuditContext.reset
+ new CallerContext("test",
+ Some("upstream"),
+ Some("app"),
+ Some("attempt"),
+ Some(1),
+ Some(2),
+ Some(3),
+ Some(4),
+ Some(5)).setCurrentContext()
+ val expected = s"SPARK_test_app_attempt_JId_1_SId_2_3_TId_4_5_upstream"
+ assert(expected === HadoopCallerContext.getCurrent.toString)
+ assert(expected === currentAuditContext.get("spark"))
}
test("encodeFileNameToURIRawPath") {
diff --git a/dev/.rat-excludes b/dev/.rat-excludes
index 06f81e9f0a540..5084d0b6905a3 100644
--- a/dev/.rat-excludes
+++ b/dev/.rat-excludes
@@ -48,6 +48,7 @@ jquery.mustache.js
pyspark-coverage-site/*
cloudpickle/*
join.py
+tblib.py
SparkILoop.scala
sbt
sbt-launch-lib.bash
diff --git a/dev/check-license b/dev/check-license
index bc7f493368a34..9836d8c464aa1 100755
--- a/dev/check-license
+++ b/dev/check-license
@@ -58,7 +58,7 @@ else
declare java_cmd=java
fi
-export RAT_VERSION=0.15
+export RAT_VERSION=0.16.1
export rat_jar="$FWDIR"/lib/apache-rat-${RAT_VERSION}.jar
mkdir -p "$FWDIR"/lib
diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh
index 8b0106696ee99..d2a7e2b845f6b 100755
--- a/dev/create-release/release-build.sh
+++ b/dev/create-release/release-build.sh
@@ -137,6 +137,12 @@ if [[ "$1" == "finalize" ]]; then
--repository-url https://upload.pypi.org/legacy/ \
"pyspark_connect-$PYSPARK_VERSION.tar.gz" \
"pyspark_connect-$PYSPARK_VERSION.tar.gz.asc"
+ svn update "pyspark_client-$RELEASE_VERSION.tar.gz"
+ svn update "pyspark_client-$RELEASE_VERSION.tar.gz.asc"
+ twine upload -u __token__ -p $PYPI_API_TOKEN \
+ --repository-url https://upload.pypi.org/legacy/ \
+ "pyspark_client-$RELEASE_VERSION.tar.gz" \
+ "pyspark_client-$RELEASE_VERSION.tar.gz.asc"
cd ..
rm -rf svn-spark
echo "PySpark uploaded"
@@ -330,6 +336,14 @@ if [[ "$1" == "package" ]]; then
--output $PYTHON_CONNECT_DIST_NAME.asc \
--detach-sig $PYTHON_CONNECT_DIST_NAME
shasum -a 512 $PYTHON_CONNECT_DIST_NAME > $PYTHON_CONNECT_DIST_NAME.sha512
+
+ PYTHON_CLIENT_DIST_NAME=pyspark_client-$PYSPARK_VERSION.tar.gz
+ cp spark-$SPARK_VERSION-bin-$NAME/python/dist/$PYTHON_CLIENT_DIST_NAME .
+
+ echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \
+ --output $PYTHON_CLIENT_DIST_NAME.asc \
+ --detach-sig $PYTHON_CLIENT_DIST_NAME
+ shasum -a 512 $PYTHON_CLIENT_DIST_NAME > $PYTHON_CLIENT_DIST_NAME.sha512
fi
echo "Copying and signing regular binary distribution"
@@ -341,7 +355,7 @@ if [[ "$1" == "package" ]]; then
if [[ -n $SPARK_CONNECT_FLAG ]]; then
echo "Copying and signing Spark Connect binary distribution"
- SPARK_CONNECT_DIST_NAME=spark-$SPARK_VERSION-bin-$NAME-spark-connect.tgz
+ SPARK_CONNECT_DIST_NAME=spark-$SPARK_VERSION-bin-$NAME-connect.tgz
cp spark-$SPARK_VERSION-bin-$NAME/$SPARK_CONNECT_DIST_NAME .
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \
--output $SPARK_CONNECT_DIST_NAME.asc \
diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh
index 43c198301b702..9d4ca1fb51503 100755
--- a/dev/create-release/release-tag.sh
+++ b/dev/create-release/release-tag.sh
@@ -73,6 +73,14 @@ cd spark
git config user.name "$GIT_NAME"
git config user.email "$GIT_EMAIL"
+# Remove test jars and classes that do not belong to source releases.
+rm $(< dev/test-jars.txt)
+rm $(< dev/test-classes.txt)
+git commit -a -m "Removing test jars and class files"
+JAR_RM_REF=$(git rev-parse HEAD)
+
# Create release version
$MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs
if [[ $RELEASE_VERSION != *"preview"* ]]; then
@@ -91,6 +99,9 @@ git commit -a -m "Preparing Spark release $RELEASE_TAG"
echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH"
git tag $RELEASE_TAG
+# Restore test jars for dev.
+git revert --no-edit $JAR_RM_REF
+
# Create next version
$MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs
# Remove -SNAPSHOT before setting the R version as R expects version strings to only have numbers
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 08608b057b8a1..3ca6f9060cd01 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -31,7 +31,7 @@ bcprov-jdk18on/1.80//bcprov-jdk18on-1.80.jar
blas/3.0.3//blas-3.0.3.jar
breeze-macros_2.13/2.1.0//breeze-macros_2.13-2.1.0.jar
breeze_2.13/2.1.0//breeze_2.13-2.1.0.jar
-bundle/2.25.53//bundle-2.25.53.jar
+bundle/2.29.52//bundle-2.29.52.jar
cats-kernel_2.13/2.8.0//cats-kernel_2.13-2.8.0.jar
checker-qual/3.43.0//checker-qual-3.43.0.jar
chill-java/0.10.0//chill-java-0.10.0.jar
@@ -39,17 +39,17 @@ chill_2.13/0.10.0//chill_2.13-0.10.0.jar
commons-cli/1.9.0//commons-cli-1.9.0.jar
commons-codec/1.18.0//commons-codec-1.18.0.jar
commons-collections/3.2.2//commons-collections-3.2.2.jar
-commons-collections4/4.4//commons-collections4-4.4.jar
+commons-collections4/4.5.0//commons-collections4-4.5.0.jar
commons-compiler/3.1.9//commons-compiler-3.1.9.jar
commons-compress/1.27.1//commons-compress-1.27.1.jar
commons-crypto/1.1.0//commons-crypto-1.1.0.jar
commons-dbcp/1.4//commons-dbcp-1.4.jar
-commons-io/2.18.0//commons-io-2.18.0.jar
+commons-io/2.19.0//commons-io-2.19.0.jar
commons-lang/2.6//commons-lang-2.6.jar
commons-lang3/3.17.0//commons-lang3-3.17.0.jar
commons-math3/3.6.1//commons-math3-3.6.1.jar
commons-pool/1.5.4//commons-pool-1.5.4.jar
-commons-text/1.13.0//commons-text-1.13.0.jar
+commons-text/1.13.1//commons-text-1.13.1.jar
compress-lzf/1.1.2//compress-lzf-1.1.2.jar
curator-client/5.7.1//curator-client-5.7.1.jar
curator-framework/5.7.1//curator-framework-5.7.1.jar
@@ -57,7 +57,7 @@ curator-recipes/5.7.1//curator-recipes-5.7.1.jar
datanucleus-api-jdo/4.2.4//datanucleus-api-jdo-4.2.4.jar
datanucleus-core/4.1.17//datanucleus-core-4.1.17.jar
datanucleus-rdbms/4.1.19//datanucleus-rdbms-4.1.19.jar
-datasketches-java/6.1.1//datasketches-java-6.1.1.jar
+datasketches-java/6.2.0//datasketches-java-6.2.0.jar
datasketches-memory/3.0.2//datasketches-memory-3.0.2.jar
derby/10.16.1.1//derby-10.16.1.1.jar
derbyshared/10.16.1.1//derbyshared-10.16.1.1.jar
@@ -129,18 +129,18 @@ javax.servlet-api/4.0.1//javax.servlet-api-4.0.1.jar
javolution/5.5.1//javolution-5.5.1.jar
jaxb-core/4.0.5//jaxb-core-4.0.5.jar
jaxb-runtime/4.0.5//jaxb-runtime-4.0.5.jar
-jcl-over-slf4j/2.0.16//jcl-over-slf4j-2.0.16.jar
+jcl-over-slf4j/2.0.17//jcl-over-slf4j-2.0.17.jar
jdo-api/3.0.1//jdo-api-3.0.1.jar
jdom2/2.0.6//jdom2-2.0.6.jar
-jersey-client/3.0.16//jersey-client-3.0.16.jar
-jersey-common/3.0.16//jersey-common-3.0.16.jar
-jersey-container-servlet-core/3.0.16//jersey-container-servlet-core-3.0.16.jar
-jersey-container-servlet/3.0.16//jersey-container-servlet-3.0.16.jar
-jersey-hk2/3.0.16//jersey-hk2-3.0.16.jar
-jersey-server/3.0.16//jersey-server-3.0.16.jar
+jersey-client/3.0.17//jersey-client-3.0.17.jar
+jersey-common/3.0.17//jersey-common-3.0.17.jar
+jersey-container-servlet-core/3.0.17//jersey-container-servlet-core-3.0.17.jar
+jersey-container-servlet/3.0.17//jersey-container-servlet-3.0.17.jar
+jersey-hk2/3.0.17//jersey-hk2-3.0.17.jar
+jersey-server/3.0.17//jersey-server-3.0.17.jar
jettison/1.5.4//jettison-1.5.4.jar
-jetty-util-ajax/11.0.24//jetty-util-ajax-11.0.24.jar
-jetty-util/11.0.24//jetty-util-11.0.24.jar
+jetty-util-ajax/11.0.25//jetty-util-ajax-11.0.25.jar
+jetty-util/11.0.25//jetty-util-11.0.25.jar
jjwt-api/0.12.6//jjwt-api-0.12.6.jar
jjwt-impl/0.12.6//jjwt-impl-0.12.6.jar
jjwt-jackson/0.12.6//jjwt-jackson-0.12.6.jar
@@ -157,8 +157,8 @@ json4s-jackson_2.13/4.0.7//json4s-jackson_2.13-4.0.7.jar
json4s-scalap_2.13/4.0.7//json4s-scalap_2.13-4.0.7.jar
jsr305/3.0.0//jsr305-3.0.0.jar
jta/1.1//jta-1.1.jar
-jul-to-slf4j/2.0.16//jul-to-slf4j-2.0.16.jar
-kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar
+jul-to-slf4j/2.0.17//jul-to-slf4j-2.0.17.jar
+kryo-shaded/4.0.3//kryo-shaded-4.0.3.jar
kubernetes-client-api/7.1.0//kubernetes-client-api-7.1.0.jar
kubernetes-client/7.1.0//kubernetes-client-7.1.0.jar
kubernetes-httpclient-vertx/7.1.0//kubernetes-httpclient-vertx-7.1.0.jar
@@ -201,33 +201,33 @@ metrics-jmx/4.2.30//metrics-jmx-4.2.30.jar
metrics-json/4.2.30//metrics-json-4.2.30.jar
metrics-jvm/4.2.30//metrics-jvm-4.2.30.jar
minlog/1.3.0//minlog-1.3.0.jar
-netty-all/4.1.118.Final//netty-all-4.1.118.Final.jar
-netty-buffer/4.1.118.Final//netty-buffer-4.1.118.Final.jar
-netty-codec-dns/4.1.118.Final//netty-codec-dns-4.1.118.Final.jar
-netty-codec-http/4.1.118.Final//netty-codec-http-4.1.118.Final.jar
-netty-codec-http2/4.1.118.Final//netty-codec-http2-4.1.118.Final.jar
-netty-codec-socks/4.1.118.Final//netty-codec-socks-4.1.118.Final.jar
-netty-codec/4.1.118.Final//netty-codec-4.1.118.Final.jar
-netty-common/4.1.118.Final//netty-common-4.1.118.Final.jar
-netty-handler-proxy/4.1.118.Final//netty-handler-proxy-4.1.118.Final.jar
-netty-handler/4.1.118.Final//netty-handler-4.1.118.Final.jar
-netty-resolver-dns/4.1.118.Final//netty-resolver-dns-4.1.118.Final.jar
-netty-resolver/4.1.118.Final//netty-resolver-4.1.118.Final.jar
+netty-all/4.1.119.Final//netty-all-4.1.119.Final.jar
+netty-buffer/4.1.119.Final//netty-buffer-4.1.119.Final.jar
+netty-codec-dns/4.1.119.Final//netty-codec-dns-4.1.119.Final.jar
+netty-codec-http/4.1.119.Final//netty-codec-http-4.1.119.Final.jar
+netty-codec-http2/4.1.119.Final//netty-codec-http2-4.1.119.Final.jar
+netty-codec-socks/4.1.119.Final//netty-codec-socks-4.1.119.Final.jar
+netty-codec/4.1.119.Final//netty-codec-4.1.119.Final.jar
+netty-common/4.1.119.Final//netty-common-4.1.119.Final.jar
+netty-handler-proxy/4.1.119.Final//netty-handler-proxy-4.1.119.Final.jar
+netty-handler/4.1.119.Final//netty-handler-4.1.119.Final.jar
+netty-resolver-dns/4.1.119.Final//netty-resolver-dns-4.1.119.Final.jar
+netty-resolver/4.1.119.Final//netty-resolver-4.1.119.Final.jar
netty-tcnative-boringssl-static/2.0.70.Final/linux-aarch_64/netty-tcnative-boringssl-static-2.0.70.Final-linux-aarch_64.jar
netty-tcnative-boringssl-static/2.0.70.Final/linux-x86_64/netty-tcnative-boringssl-static-2.0.70.Final-linux-x86_64.jar
netty-tcnative-boringssl-static/2.0.70.Final/osx-aarch_64/netty-tcnative-boringssl-static-2.0.70.Final-osx-aarch_64.jar
netty-tcnative-boringssl-static/2.0.70.Final/osx-x86_64/netty-tcnative-boringssl-static-2.0.70.Final-osx-x86_64.jar
netty-tcnative-boringssl-static/2.0.70.Final/windows-x86_64/netty-tcnative-boringssl-static-2.0.70.Final-windows-x86_64.jar
netty-tcnative-classes/2.0.70.Final//netty-tcnative-classes-2.0.70.Final.jar
-netty-transport-classes-epoll/4.1.118.Final//netty-transport-classes-epoll-4.1.118.Final.jar
-netty-transport-classes-kqueue/4.1.118.Final//netty-transport-classes-kqueue-4.1.118.Final.jar
-netty-transport-native-epoll/4.1.118.Final/linux-aarch_64/netty-transport-native-epoll-4.1.118.Final-linux-aarch_64.jar
-netty-transport-native-epoll/4.1.118.Final/linux-riscv64/netty-transport-native-epoll-4.1.118.Final-linux-riscv64.jar
-netty-transport-native-epoll/4.1.118.Final/linux-x86_64/netty-transport-native-epoll-4.1.118.Final-linux-x86_64.jar
-netty-transport-native-kqueue/4.1.118.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.118.Final-osx-aarch_64.jar
-netty-transport-native-kqueue/4.1.118.Final/osx-x86_64/netty-transport-native-kqueue-4.1.118.Final-osx-x86_64.jar
-netty-transport-native-unix-common/4.1.118.Final//netty-transport-native-unix-common-4.1.118.Final.jar
-netty-transport/4.1.118.Final//netty-transport-4.1.118.Final.jar
+netty-transport-classes-epoll/4.1.119.Final//netty-transport-classes-epoll-4.1.119.Final.jar
+netty-transport-classes-kqueue/4.1.119.Final//netty-transport-classes-kqueue-4.1.119.Final.jar
+netty-transport-native-epoll/4.1.119.Final/linux-aarch_64/netty-transport-native-epoll-4.1.119.Final-linux-aarch_64.jar
+netty-transport-native-epoll/4.1.119.Final/linux-riscv64/netty-transport-native-epoll-4.1.119.Final-linux-riscv64.jar
+netty-transport-native-epoll/4.1.119.Final/linux-x86_64/netty-transport-native-epoll-4.1.119.Final-linux-x86_64.jar
+netty-transport-native-kqueue/4.1.119.Final/osx-aarch_64/netty-transport-native-kqueue-4.1.119.Final-osx-aarch_64.jar
+netty-transport-native-kqueue/4.1.119.Final/osx-x86_64/netty-transport-native-kqueue-4.1.119.Final-osx-x86_64.jar
+netty-transport-native-unix-common/4.1.119.Final//netty-transport-native-unix-common-4.1.119.Final.jar
+netty-transport/4.1.119.Final//netty-transport-4.1.119.Final.jar
objenesis/3.3//objenesis-3.3.jar
okhttp/3.12.12//okhttp-3.12.12.jar
okio/1.17.6//okio-1.17.6.jar
@@ -235,23 +235,23 @@ opencsv/2.3//opencsv-2.3.jar
opentracing-api/0.33.0//opentracing-api-0.33.0.jar
opentracing-noop/0.33.0//opentracing-noop-0.33.0.jar
opentracing-util/0.33.0//opentracing-util-0.33.0.jar
-orc-core/2.1.0/shaded-protobuf/orc-core-2.1.0-shaded-protobuf.jar
-orc-format/1.0.0/shaded-protobuf/orc-format-1.0.0-shaded-protobuf.jar
-orc-mapreduce/2.1.0/shaded-protobuf/orc-mapreduce-2.1.0-shaded-protobuf.jar
-orc-shims/2.1.0//orc-shims-2.1.0.jar
+orc-core/2.1.1/shaded-protobuf/orc-core-2.1.1-shaded-protobuf.jar
+orc-format/1.1.0/shaded-protobuf/orc-format-1.1.0-shaded-protobuf.jar
+orc-mapreduce/2.1.1/shaded-protobuf/orc-mapreduce-2.1.1-shaded-protobuf.jar
+orc-shims/2.1.1//orc-shims-2.1.1.jar
oro/2.0.8//oro-2.0.8.jar
osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar
paranamer/2.8//paranamer-2.8.jar
-parquet-column/1.15.0//parquet-column-1.15.0.jar
-parquet-common/1.15.0//parquet-common-1.15.0.jar
-parquet-encoding/1.15.0//parquet-encoding-1.15.0.jar
-parquet-format-structures/1.15.0//parquet-format-structures-1.15.0.jar
-parquet-hadoop/1.15.0//parquet-hadoop-1.15.0.jar
-parquet-jackson/1.15.0//parquet-jackson-1.15.0.jar
+parquet-column/1.15.1//parquet-column-1.15.1.jar
+parquet-common/1.15.1//parquet-common-1.15.1.jar
+parquet-encoding/1.15.1//parquet-encoding-1.15.1.jar
+parquet-format-structures/1.15.1//parquet-format-structures-1.15.1.jar
+parquet-hadoop/1.15.1//parquet-hadoop-1.15.1.jar
+parquet-jackson/1.15.1//parquet-jackson-1.15.1.jar
pickle/1.5//pickle-1.5.jar
py4j/0.10.9.9//py4j-0.10.9.9.jar
remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar
-rocksdbjni/9.10.0//rocksdbjni-9.10.0.jar
+rocksdbjni/9.8.4//rocksdbjni-9.8.4.jar
scala-collection-compat_2.13/2.7.0//scala-collection-compat_2.13-2.7.0.jar
scala-compiler/2.13.16//scala-compiler-2.13.16.jar
scala-library/2.13.16//scala-library-2.13.16.jar
@@ -259,7 +259,7 @@ scala-parallel-collections_2.13/1.2.0//scala-parallel-collections_2.13-1.2.0.jar
scala-parser-combinators_2.13/2.4.0//scala-parser-combinators_2.13-2.4.0.jar
scala-reflect/2.13.16//scala-reflect-2.13.16.jar
scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar
-slf4j-api/2.0.16//slf4j-api-2.0.16.jar
+slf4j-api/2.0.17//slf4j-api-2.0.17.jar
snakeyaml-engine/2.9//snakeyaml-engine-2.9.jar
snakeyaml/2.3//snakeyaml-2.3.jar
snappy-java/1.1.10.7//snappy-java-1.1.10.7.jar
@@ -279,7 +279,7 @@ vertx-core/4.5.12//vertx-core-4.5.12.jar
vertx-web-client/4.5.12//vertx-web-client-4.5.12.jar
vertx-web-common/4.5.12//vertx-web-common-4.5.12.jar
wildfly-openssl/2.2.5.Final//wildfly-openssl-2.2.5.Final.jar
-xbean-asm9-shaded/4.26//xbean-asm9-shaded-4.26.jar
+xbean-asm9-shaded/4.27//xbean-asm9-shaded-4.27.jar
xmlschema-core/2.3.1//xmlschema-core-2.3.1.jar
xz/1.10//xz-1.10.jar
zjsonpatch/7.1.0//zjsonpatch-7.1.0.jar
diff --git a/dev/eslint.js b/dev/eslint.js
index 24b5170b436a9..abb06526fe966 100644
--- a/dev/eslint.js
+++ b/dev/eslint.js
@@ -40,6 +40,7 @@ module.exports = {
"dataTables.rowsGroup.js"
],
"parserOptions": {
- "sourceType": "module"
+ "sourceType": "module",
+ "ecmaVersion": "latest"
}
}
diff --git a/dev/is-changed.py b/dev/is-changed.py
index 1962e244d5dd7..4c3adf691327e 100755
--- a/dev/is-changed.py
+++ b/dev/is-changed.py
@@ -65,6 +65,35 @@ def main():
changed_files = identify_changed_files_from_git_commits(
os.environ["GITHUB_SHA"], target_ref=os.environ["GITHUB_PREV_SHA"]
)
+
+ if any(f.endswith(".jar") for f in changed_files):
+ with open(
+ os.path.join(os.path.dirname(os.path.realpath(__file__)), "test-jars.txt")
+ ) as jarlist:
+ itrsect = set((line.strip() for line in jarlist.readlines())).intersection(
+ set(changed_files)
+ )
+ if len(itrsect) > 0:
+ raise SystemExit(
+ f"Cannot include jars in source codes ({', '.join(itrsect)}). "
+ "If they have to be added temporarily, "
+ "please add the file name into dev/test-jars.txt."
+ )
+
+ if any(f.endswith(".class") for f in changed_files):
+ with open(
+ os.path.join(os.path.dirname(os.path.realpath(__file__)), "test-classes.txt")
+ ) as clslist:
+ itrsect = set((line.strip() for line in clslist.readlines())).intersection(
+ set(changed_files)
+ )
+ if len(itrsect) > 0:
+ raise SystemExit(
+ f"Cannot include class files in source codes ({', '.join(itrsect)}). "
+ "If they have to be added temporarily, "
+ "please add the file name into dev/test-classes.txt."
+ )
+
changed_modules = determine_modules_to_test(
determine_modules_for_files(changed_files), deduplicated=False
)
diff --git a/dev/lint-js b/dev/lint-js
index 1a94348b7430a..6ec66df47a736 100755
--- a/dev/lint-js
+++ b/dev/lint-js
@@ -45,7 +45,7 @@ if ! npm ls eslint > /dev/null; then
npm ci eslint
fi
-npx eslint -c "$SPARK_ROOT_DIR/dev/eslint.js" ${LINT_TARGET_FILES[@]} | tee "$LINT_JS_REPORT_FILE_NAME"
+npx eslint -c "$SPARK_ROOT_DIR/dev/eslint.js" "${LINT_TARGET_FILES[@]}" "$@" | tee "$LINT_JS_REPORT_FILE_NAME"
lint_status=$?
if [ "$lint_status" = "0" ] ; then
diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh
index 2a9fa4d4d0f2c..39f6477e07c97 100755
--- a/dev/make-distribution.sh
+++ b/dev/make-distribution.sh
@@ -38,15 +38,17 @@ MAKE_R=false
MAKE_SPARK_CONNECT=false
NAME=none
MVN="$SPARK_HOME/build/mvn"
+SBT_ENABLED=false
+SBT="$SPARK_HOME/build/sbt"
function exit_with_usage {
set +x
echo "make-distribution.sh - tool for making binary distributions of Spark"
echo ""
echo "usage:"
- cl_options="[--name] [--tgz] [--pip] [--r] [--connect] [--mvn <mvn-command>]"
- echo "make-distribution.sh $cl_options <maven build options>"
- echo "See Spark's \"Building Spark\" doc for correct Maven options."
+ cl_options="[--name] [--tgz] [--pip] [--r] [--connect] [--mvn <mvn-command>] [--sbt-enabled] [--sbt <sbt-command>]"
+ echo "make-distribution.sh $cl_options <maven/sbt build options>"
+ echo "See Spark's \"Building Spark\" doc for correct Maven/SBT options."
echo "SparkR is deprecated from Apache Spark 4.0.0 and will be removed in a future version."
echo ""
exit 1
@@ -71,6 +73,13 @@ while (( "$#" )); do
MVN="$2"
shift
;;
+ --sbt-enabled)
+ SBT_ENABLED=true
+ ;;
+ --sbt)
+ SBT="$2"
+ shift
+ ;;
--name)
NAME="$2"
shift
@@ -124,32 +133,25 @@ if [ $(command -v git) ]; then
unset GITREV
fi
-
-if [ ! "$(command -v "$MVN")" ] ; then
- echo -e "Could not locate Maven command: '$MVN'."
- echo -e "Specify the Maven command with the --mvn flag"
- exit -1;
+if [ "$SBT_ENABLED" == "true" && ! "$(command -v "$SBT")" ]; then
+ echo -e "Could not locate SBT command: '$SBT'."
+ echo -e "Specify the SBT command with the --sbt flag"
+ exit -1;
+elif [ ! "$(command -v "$MVN")" ]; then
+ echo -e "Could not locate Maven command: '$MVN'."
+ echo -e "Specify the Maven command with the --mvn flag"
+ exit -1;
fi
-VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ \
- | grep -v "INFO"\
- | grep -v "WARNING"\
- | tail -n 1)
-SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ \
- | grep -v "INFO"\
- | grep -v "WARNING"\
- | tail -n 1)
-SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ \
- | grep -v "INFO"\
- | grep -v "WARNING"\
- | tail -n 1)
-SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ \
- | grep -v "INFO"\
- | grep -v "WARNING"\
- | grep -F --count "hive";\
- # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\
- # because we use "set -o pipefail"
- echo -n)
+if [ "$SBT_ENABLED" == "true" ]; then
+ VERSION=$("$SBT" -no-colors "show version" | awk '/\[info\]/{ver=$2} END{print ver}')
+ SCALA_VERSION=$("$SBT" -no-colors "show scalaBinaryVersion" | awk '/\[info\]/{ver=$2} END{print ver}')
+ SPARK_HADOOP_VERSION=$("$SBT" -no-colors "show hadoopVersion" | awk '/\[info\]/{ver=$2} END{print ver}')
+else
+ VERSION=$("$MVN" help:evaluate -Dexpression=project.version "$@" -q -DforceStdout)
+ SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version "$@" -q -DforceStdout)
+ SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version "$@" -q -DforceStdout)
+fi
if [ "$NAME" == "none" ]; then
NAME=$SPARK_HADOOP_VERSION
@@ -166,18 +168,26 @@ fi
# Build uber fat JAR
cd "$SPARK_HOME"
-export MAVEN_OPTS="${MAVEN_OPTS:--Xss128m -Xmx4g -XX:ReservedCodeCacheSize=128m}"
-
-# Store the command as an array because $MVN variable might have spaces in it.
-# Normal quoting tricks don't work.
-# See: http://mywiki.wooledge.org/BashFAQ/050
-BUILD_COMMAND=("$MVN" clean package \
- -DskipTests \
- -Dmaven.javadoc.skip=true \
- -Dmaven.scaladoc.skip=true \
- -Dmaven.source.skip \
- -Dcyclonedx.skip=true \
- $@)
+if [ "$SBT_ENABLED" == "true" ] ; then
+ export NOLINT_ON_COMPILE=1
+ # Store the command as an array because $SBT variable might have spaces in it.
+ # Normal quoting tricks don't work.
+ # See: http://mywiki.wooledge.org/BashFAQ/050
+ BUILD_COMMAND=("$SBT" clean package $@)
+else
+ export MAVEN_OPTS="${MAVEN_OPTS:--Xss128m -Xmx4g -XX:ReservedCodeCacheSize=128m}"
+
+ # Store the command as an array because $MVN variable might have spaces in it.
+ # Normal quoting tricks don't work.
+ # See: http://mywiki.wooledge.org/BashFAQ/050
+ BUILD_COMMAND=("$MVN" clean package \
+ -DskipTests \
+ -Dmaven.javadoc.skip=true \
+ -Dmaven.scaladoc.skip=true \
+ -Dmaven.source.skip \
+ -Dcyclonedx.skip=true \
+ $@)
+fi
# Actually build the jar
echo -e "\nBuilding with..."
@@ -315,17 +325,17 @@ if [ "$MAKE_TGZ" == "true" ]; then
$TAR -czf "spark-$VERSION-bin-$NAME.tgz" -C "$SPARK_HOME" "$TARDIR_NAME"
rm -rf "$TARDIR"
if [[ "$MAKE_SPARK_CONNECT" == "true" ]]; then
- TARDIR_NAME=spark-$VERSION-bin-$NAME-spark-connect
+ TARDIR_NAME=spark-$VERSION-bin-$NAME-connect
TARDIR="$SPARK_HOME/$TARDIR_NAME"
rm -rf "$TARDIR"
cp -r "$DISTDIR" "$TARDIR"
# Set the Spark Connect system variable in these scripts to enable it by default.
- awk 'NR==1{print; print "export SPARK_CONNECT_MODE=1"; next} {print}' "$TARDIR/bin/pyspark" > tmp && cat tmp > "$TARDIR/bin/pyspark"
- awk 'NR==1{print; print "export SPARK_CONNECT_MODE=1"; next} {print}' "$TARDIR/bin/spark-shell" > tmp && cat tmp > "$TARDIR/bin/spark-shell"
- awk 'NR==1{print; print "export SPARK_CONNECT_MODE=1"; next} {print}' "$TARDIR/bin/spark-submit" > tmp && cat tmp > "$TARDIR/bin/spark-submit"
- awk 'NR==1{print; print "set SPARK_CONNECT_MODE=1"; next} {print}' "$TARDIR/bin/pyspark2.cmd" > tmp && cat tmp > "$TARDIR/bin/pyspark2.cmd"
- awk 'NR==1{print; print "set SPARK_CONNECT_MODE=1"; next} {print}' "$TARDIR/bin/spark-shell2.cmd" > tmp && cat tmp > "$TARDIR/bin/spark-shell2.cmd"
- awk 'NR==1{print; print "set SPARK_CONNECT_MODE=1"; next} {print}' "$TARDIR/bin/spark-submit2.cmd" > tmp && cat tmp > "$TARDIR/bin/spark-submit2.cmd"
+ awk 'NR==1{print; print "export SPARK_CONNECT_MODE=${SPARK_CONNECT_MODE:-1}"; next} {print}' "$TARDIR/bin/pyspark" > tmp && cat tmp > "$TARDIR/bin/pyspark"
+ awk 'NR==1{print; print "export SPARK_CONNECT_MODE=${SPARK_CONNECT_MODE:-1}"; next} {print}' "$TARDIR/bin/spark-shell" > tmp && cat tmp > "$TARDIR/bin/spark-shell"
+ awk 'NR==1{print; print "export SPARK_CONNECT_MODE=${SPARK_CONNECT_MODE:-1}"; next} {print}' "$TARDIR/bin/spark-submit" > tmp && cat tmp > "$TARDIR/bin/spark-submit"
+ awk 'NR==1{print; print "if [%SPARK_CONNECT_MODE%] == [] set SPARK_CONNECT_MODE=1"; next} {print}' "$TARDIR/bin/pyspark2.cmd" > tmp && cat tmp > "$TARDIR/bin/pyspark2.cmd"
+ awk 'NR==1{print; print "if [%SPARK_CONNECT_MODE%] == [] set SPARK_CONNECT_MODE=1"; next} {print}' "$TARDIR/bin/spark-shell2.cmd" > tmp && cat tmp > "$TARDIR/bin/spark-shell2.cmd"
+ awk 'NR==1{print; print "if [%SPARK_CONNECT_MODE%] == [] set SPARK_CONNECT_MODE=1"; next} {print}' "$TARDIR/bin/spark-submit2.cmd" > tmp && cat tmp > "$TARDIR/bin/spark-submit2.cmd"
rm tmp
$TAR -czf "$TARDIR_NAME.tgz" -C "$SPARK_HOME" "$TARDIR_NAME"
rm -rf "$TARDIR"
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index 415f468a11577..e9e82b010032f 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -566,6 +566,7 @@ def get_current_ref():
def initialize_jira():
global asf_jira
+ asf_jira = None
jira_server = {"server": JIRA_API_BASE}
if not JIRA_IMPORTED:
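Initializing `asf_jira = None` up front keeps the module-level handle defined even when the early-return paths below are taken (for example when the jira package cannot be imported), so later `if asf_jira is None` style checks cannot raise a NameError. A standalone Python sketch of the assumed intent, not the merge script itself:

    asf_jira = None  # defined even if setup bails out early

    def initialize_jira(jira_imported):
        global asf_jira
        asf_jira = None
        if not jira_imported:
            return  # leave the handle as None; callers must check before use
        # asf_jira = JIRA(...)  # hypothetical: real setup needs a server URL and credentials

    initialize_jira(jira_imported=False)
    assert asf_jira is None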
diff --git a/dev/scalastyle b/dev/scalastyle
index 9de1fd1c9d9d5..0428453b62c81 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -17,15 +17,13 @@
# limitations under the License.
#
-SPARK_PROFILES=${1:-"-Pkubernetes -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive -Pvolcano -Pjvm-profiler -Phadoop-cloud"}
+SPARK_PROFILES=${1:-"-Pkubernetes -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive -Pvolcano -Pjvm-profiler -Phadoop-cloud -Pdocker-integration-tests -Pkubernetes-integration-tests"}
# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file
# with failure (either resolution or compilation); the "q" makes SBT quit.
ERRORS=$(echo -e "q\n" \
| build/sbt \
${SPARK_PROFILES} \
- -Pdocker-integration-tests \
- -Pkubernetes-integration-tests \
scalastyle test:scalastyle \
| awk '{if($1~/error/)print}' \
)
@@ -36,3 +34,4 @@ if test ! -z "$ERRORS"; then
else
echo -e "Scalastyle checks passed."
fi
+
diff --git a/dev/spark-test-image/lint/Dockerfile b/dev/spark-test-image/lint/Dockerfile
index c3ffd7ba4e4b2..c61310e89f0d8 100644
--- a/dev/spark-test-image/lint/Dockerfile
+++ b/dev/spark-test-image/lint/Dockerfile
@@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image for Linter"
# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
LABEL org.opencontainers.image.version=""
-ENV FULL_REFRESH_DATE=20241112
+ENV FULL_REFRESH_DATE=20250312
ENV DEBIAN_FRONTEND=noninteractive
ENV DEBCONF_NONINTERACTIVE_SEEN=true
@@ -88,7 +88,7 @@ RUN python3.9 -m pip install \
'pandas' \
'pandas-stubs==1.2.0.53' \
'plotly>=4.8' \
- 'pyarrow>=18.0.0' \
+ 'pyarrow>=19.0.0' \
'pytest-mypy-plugins==1.9.3' \
'pytest==7.1.3' \
&& python3.9 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu \
diff --git a/dev/spark-test-image/numpy-213/Dockerfile b/dev/spark-test-image/numpy-213/Dockerfile
new file mode 100644
index 0000000000000..f3ce7b7091e12
--- /dev/null
+++ b/dev/spark-test-image/numpy-213/Dockerfile
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Image for building and testing Spark branches. Based on Ubuntu 22.04.
+# See also in https://hub.docker.com/_/ubuntu
+FROM ubuntu:jammy-20240911.1
+LABEL org.opencontainers.image.authors="Apache Spark project <dev@spark.apache.org>"
+LABEL org.opencontainers.image.licenses="Apache-2.0"
+LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For PySpark with Python 3.11 and Numpy 2.1.3"
+# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
+LABEL org.opencontainers.image.version=""
+
+ENV FULL_REFRESH_DATE=20250327
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV DEBCONF_NONINTERACTIVE_SEEN=true
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ ca-certificates \
+ curl \
+ gfortran \
+ git \
+ gnupg \
+ libcurl4-openssl-dev \
+ libfontconfig1-dev \
+ libfreetype6-dev \
+ libfribidi-dev \
+ libgit2-dev \
+ libharfbuzz-dev \
+ libjpeg-dev \
+ liblapack-dev \
+ libopenblas-dev \
+ libpng-dev \
+ libpython3-dev \
+ libssl-dev \
+ libtiff5-dev \
+ libxml2-dev \
+ openjdk-17-jdk-headless \
+ pkg-config \
+ qpdf \
+ tzdata \
+ software-properties-common \
+ wget \
+ zlib1g-dev
+
+# Install Python 3.11
+RUN add-apt-repository ppa:deadsnakes/ppa
+RUN apt-get update && apt-get install -y \
+ python3.11 \
+ && apt-get autoremove --purge -y \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+
+# Pin numpy==2.1.3
+ARG BASIC_PIP_PKGS="numpy==2.1.3 pyarrow==19.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+# Python deps for Spark Connect
+ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+
+# Install Python 3.11 packages
+RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
+RUN python3.11 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.11 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
+ python3.11 -m pip cache purge
diff --git a/dev/spark-test-image/python-309/Dockerfile b/dev/spark-test-image/python-309/Dockerfile
index c8709205b8e38..7fd4b604225c8 100644
--- a/dev/spark-test-image/python-309/Dockerfile
+++ b/dev/spark-test-image/python-309/Dockerfile
@@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For PySpark wi
# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
LABEL org.opencontainers.image.version=""
-ENV FULL_REFRESH_DATE=20241205
+ENV FULL_REFRESH_DATE=20250312
ENV DEBIAN_FRONTEND=noninteractive
ENV DEBCONF_NONINTERACTIVE_SEEN=true
@@ -67,7 +67,7 @@ RUN apt-get update && apt-get install -y \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=18.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=19.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 googleapis-common-protos==1.65.0 graphviz==0.20.3"
diff --git a/dev/spark-test-image/python-310/Dockerfile b/dev/spark-test-image/python-310/Dockerfile
index a44a8b4a2691b..57c6c850a6219 100644
--- a/dev/spark-test-image/python-310/Dockerfile
+++ b/dev/spark-test-image/python-310/Dockerfile
@@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For PySpark wi
# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
LABEL org.opencontainers.image.version=""
-ENV FULL_REFRESH_DATE=20241205
+ENV FULL_REFRESH_DATE=20250312
ENV DEBIAN_FRONTEND=noninteractive
ENV DEBCONF_NONINTERACTIVE_SEEN=true
@@ -63,7 +63,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=18.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=19.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 googleapis-common-protos==1.65.0 graphviz==0.20.3"
diff --git a/dev/spark-test-image/python-311-classic-only/Dockerfile b/dev/spark-test-image/python-311-classic-only/Dockerfile
new file mode 100644
index 0000000000000..8f2ec0b0dd1f5
--- /dev/null
+++ b/dev/spark-test-image/python-311-classic-only/Dockerfile
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Image for building and testing Spark branches. Based on Ubuntu 22.04.
+# See also in https://hub.docker.com/_/ubuntu
+FROM ubuntu:jammy-20240911.1
+LABEL org.opencontainers.image.authors="Apache Spark project <dev@spark.apache.org>"
+LABEL org.opencontainers.image.licenses="Apache-2.0"
+LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For PySpark Classic with Python 3.11"
+# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
+LABEL org.opencontainers.image.version=""
+
+ENV FULL_REFRESH_DATE=20250424
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV DEBCONF_NONINTERACTIVE_SEEN=true
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ ca-certificates \
+ curl \
+ gfortran \
+ git \
+ gnupg \
+ libcurl4-openssl-dev \
+ libfontconfig1-dev \
+ libfreetype6-dev \
+ libfribidi-dev \
+ libgit2-dev \
+ libharfbuzz-dev \
+ libjpeg-dev \
+ liblapack-dev \
+ libopenblas-dev \
+ libpng-dev \
+ libpython3-dev \
+ libssl-dev \
+ libtiff5-dev \
+ libxml2-dev \
+ openjdk-17-jdk-headless \
+ pkg-config \
+ qpdf \
+ tzdata \
+ software-properties-common \
+ wget \
+ zlib1g-dev
+
+# Install Python 3.11
+RUN add-apt-repository ppa:deadsnakes/ppa
+RUN apt-get update && apt-get install -y \
+ python3.11 \
+ && apt-get autoremove --purge -y \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+
+ARG BASIC_PIP_PKGS="numpy pyarrow>=19.0.0 pandas==2.2.3 plotly<6.0.0 matplotlib openpyxl memory-profiler>=0.61.0 mlflow>=2.8.1 scipy scikit-learn>=1.3.2"
+ARG TEST_PIP_PKGS="coverage unittest-xml-reporting"
+
+# Install Python 3.11 packages
+RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
+RUN python3.11 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.11 -m pip install $BASIC_PIP_PKGS $TEST_PIP_PKGS && \
+ python3.11 -m pip install 'torch<2.6.0' torchvision --index-url https://download.pytorch.org/whl/cpu && \
+ python3.11 -m pip install deepspeed torcheval && \
+ python3.11 -m pip cache purge
diff --git a/dev/spark-test-image/python-311/Dockerfile b/dev/spark-test-image/python-311/Dockerfile
index 646d5a63fc510..1a2caa483785b 100644
--- a/dev/spark-test-image/python-311/Dockerfile
+++ b/dev/spark-test-image/python-311/Dockerfile
@@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For PySpark wi
# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
LABEL org.opencontainers.image.version=""
-ENV FULL_REFRESH_DATE=20241212
+ENV FULL_REFRESH_DATE=20250312
ENV DEBIAN_FRONTEND=noninteractive
ENV DEBCONF_NONINTERACTIVE_SEEN=true
@@ -67,7 +67,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=18.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=19.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 googleapis-common-protos==1.65.0 graphviz==0.20.3"
diff --git a/dev/spark-test-image/python-312/Dockerfile b/dev/spark-test-image/python-312/Dockerfile
index c2c9fe211695a..f64e3e3ba30ce 100644
--- a/dev/spark-test-image/python-312/Dockerfile
+++ b/dev/spark-test-image/python-312/Dockerfile
@@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For PySpark wi
# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
LABEL org.opencontainers.image.version=""
-ENV FULL_REFRESH_DATE=20241206
+ENV FULL_REFRESH_DATE=20250312
ENV DEBIAN_FRONTEND=noninteractive
ENV DEBCONF_NONINTERACTIVE_SEEN=true
@@ -67,7 +67,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=18.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=19.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 googleapis-common-protos==1.65.0 graphviz==0.20.3"
diff --git a/dev/spark-test-image/python-313-nogil/Dockerfile b/dev/spark-test-image/python-313-nogil/Dockerfile
new file mode 100644
index 0000000000000..cee6a4cca4d33
--- /dev/null
+++ b/dev/spark-test-image/python-313-nogil/Dockerfile
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Image for building and testing Spark branches. Based on Ubuntu 22.04.
+# See also in https://hub.docker.com/_/ubuntu
+FROM ubuntu:jammy-20240911.1
+LABEL org.opencontainers.image.authors="Apache Spark project <dev@spark.apache.org>"
+LABEL org.opencontainers.image.licenses="Apache-2.0"
+LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For PySpark with Python 3.13 (no GIL)"
+# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
+LABEL org.opencontainers.image.version=""
+
+ENV FULL_REFRESH_DATE=20250407
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV DEBCONF_NONINTERACTIVE_SEEN=true
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ ca-certificates \
+ curl \
+ gfortran \
+ git \
+ gnupg \
+ libcurl4-openssl-dev \
+ libfontconfig1-dev \
+ libfreetype6-dev \
+ libfribidi-dev \
+ libgit2-dev \
+ libharfbuzz-dev \
+ libjpeg-dev \
+ liblapack-dev \
+ libopenblas-dev \
+ libpng-dev \
+ libpython3-dev \
+ libssl-dev \
+ libtiff5-dev \
+ libxml2-dev \
+ openjdk-17-jdk-headless \
+ pkg-config \
+ qpdf \
+ tzdata \
+ software-properties-common \
+ wget \
+ zlib1g-dev
+
+# Install Python 3.13 (no GIL)
+RUN add-apt-repository ppa:deadsnakes/ppa
+RUN apt-get update && apt-get install -y \
+ python3.13-nogil \
+ && apt-get autoremove --purge -y \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+
+ARG BASIC_PIP_PKGS="numpy pyarrow>=19.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+
+
+# Install Python 3.13 packages
+RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13t
+# TODO: Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS when it supports Python 3.13 free threaded
+# TODO: Add lxml, grpcio, grpcio-status back when they support Python 3.13 free threaded
+RUN python3.13t -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.13t -m pip install numpy>=2.1 pyarrow>=19.0.0 six==1.16.0 pandas==2.2.3 scipy coverage matplotlib openpyxl jinja2 && \
+ python3.13t -m pip cache purge
diff --git a/dev/spark-test-image/python-313/Dockerfile b/dev/spark-test-image/python-313/Dockerfile
index 6ad741d890da7..aede82ac7d78c 100644
--- a/dev/spark-test-image/python-313/Dockerfile
+++ b/dev/spark-test-image/python-313/Dockerfile
@@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For PySpark wi
# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
LABEL org.opencontainers.image.version=""
-ENV FULL_REFRESH_DATE=20241210
+ENV FULL_REFRESH_DATE=20250312
ENV DEBIAN_FRONTEND=noninteractive
ENV DEBCONF_NONINTERACTIVE_SEEN=true
@@ -67,13 +67,14 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=18.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=19.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 googleapis-common-protos==1.65.0 graphviz==0.20.3"
-
# Install Python 3.13 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13
-# TODO(SPARK-49862) Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS to Python 3.13 image when it supports Python 3.13
RUN python3.13 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
-RUN python3.13 -m pip install numpy>=2.1 pyarrow>=18.0.0 six==1.16.0 pandas==2.2.3 scipy coverage matplotlib openpyxl grpcio==1.67.0 grpcio-status==1.67.0 lxml jinja2 && \
+RUN python3.13 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \
+ python3.13 -m pip install 'torch<2.6.0' torchvision --index-url https://download.pytorch.org/whl/cpu && \
+ python3.13 -m pip install torcheval && \
python3.13 -m pip cache purge
diff --git a/dev/spark-test-image/python-minimum/Dockerfile b/dev/spark-test-image/python-minimum/Dockerfile
index 82e2508ec6e32..59d9ebed4e40f 100644
--- a/dev/spark-test-image/python-minimum/Dockerfile
+++ b/dev/spark-test-image/python-minimum/Dockerfile
@@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For PySpark wi
# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
LABEL org.opencontainers.image.version=""
-ENV FULL_REFRESH_DATE=20241223
+ENV FULL_REFRESH_DATE=20250327
ENV DEBIAN_FRONTEND=noninteractive
ENV DEBCONF_NONINTERACTIVE_SEEN=true
diff --git a/dev/spark-test-image/python-ps-minimum/Dockerfile b/dev/spark-test-image/python-ps-minimum/Dockerfile
index 913da06c551ca..0cdf1fa6aa1f1 100644
--- a/dev/spark-test-image/python-ps-minimum/Dockerfile
+++ b/dev/spark-test-image/python-ps-minimum/Dockerfile
@@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For Pandas API
# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
LABEL org.opencontainers.image.version=""
-ENV FULL_REFRESH_DATE=20250102
+ENV FULL_REFRESH_DATE=20250327
ENV DEBIAN_FRONTEND=noninteractive
ENV DEBCONF_NONINTERACTIVE_SEEN=true
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 6d1ffe0afdd13..497963e76109c 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -510,6 +510,7 @@ def __hash__(self):
"pyspark.sql.observation",
"pyspark.sql.tvf",
# unittests
+ "pyspark.sql.tests.test_artifact",
"pyspark.sql.tests.test_catalog",
"pyspark.sql.tests.test_column",
"pyspark.sql.tests.test_conf",
@@ -1035,12 +1036,14 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_connect_plan",
"pyspark.sql.tests.connect.test_connect_basic",
"pyspark.sql.tests.connect.test_connect_dataframe_property",
+ "pyspark.sql.tests.connect.test_connect_channel",
"pyspark.sql.tests.connect.test_connect_error",
"pyspark.sql.tests.connect.test_connect_function",
"pyspark.sql.tests.connect.test_connect_collection",
"pyspark.sql.tests.connect.test_connect_column",
"pyspark.sql.tests.connect.test_connect_creation",
"pyspark.sql.tests.connect.test_connect_readwriter",
+ "pyspark.sql.tests.connect.test_connect_retry",
"pyspark.sql.tests.connect.test_connect_session",
"pyspark.sql.tests.connect.test_connect_stat",
"pyspark.sql.tests.connect.test_parity_datasources",
@@ -1114,6 +1117,7 @@ def __hash__(self):
# ml doctests
"pyspark.ml.connect.functions",
# ml unittests
+ "pyspark.ml.tests.connect.test_connect_cache",
"pyspark.ml.tests.connect.test_connect_function",
"pyspark.ml.tests.connect.test_parity_torch_distributor",
"pyspark.ml.tests.connect.test_parity_torch_data_loader",
@@ -1463,7 +1467,10 @@ def __hash__(self):
],
python_test_goals=[
# unittests
+ "pyspark.errors.tests.test_connect_errors_conversion",
"pyspark.errors.tests.test_errors",
+ "pyspark.errors.tests.test_traceback",
+ "pyspark.errors.tests.connect.test_parity_traceback",
],
)
diff --git a/dev/test-classes.txt b/dev/test-classes.txt
new file mode 100644
index 0000000000000..5315c970c5bab
--- /dev/null
+++ b/dev/test-classes.txt
@@ -0,0 +1,8 @@
+repl/src/test/resources/IntSumUdf.class
+sql/core/src/test/resources/artifact-tests/Hello.class
+sql/core/src/test/resources/artifact-tests/IntSumUdf.class
+sql/core/src/test/resources/artifact-tests/smallClassFile.class
+sql/connect/common/src/test/resources/artifact-tests/Hello.class
+sql/core/src/test/resources/artifact-tests/HelloWithPackage.class
+sql/connect/common/src/test/resources/artifact-tests/smallClassFile.class
+sql/connect/common/src/test/resources/artifact-tests/smallClassFileDup.class
diff --git a/dev/test-jars.txt b/dev/test-jars.txt
new file mode 100644
index 0000000000000..bd8fc93bc9f0f
--- /dev/null
+++ b/dev/test-jars.txt
@@ -0,0 +1,17 @@
+core/src/test/resources/TestHelloV2_2.13.jar
+core/src/test/resources/TestHelloV3_2.13.jar
+core/src/test/resources/TestUDTF.jar
+data/artifact-tests/junitLargeJar.jar
+data/artifact-tests/smallJar.jar
+sql/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar
+sql/connect/client/jvm/src/test/resources/udf2.13.jar
+sql/connect/common/src/test/resources/artifact-tests/junitLargeJar.jar
+sql/connect/common/src/test/resources/artifact-tests/smallJar.jar
+sql/core/src/test/resources/SPARK-33084.jar
+sql/core/src/test/resources/artifact-tests/udf_noA.jar
+sql/hive-thriftserver/src/test/resources/TestUDTF.jar
+sql/hive/src/test/noclasspath/hive-test-udfs.jar
+sql/hive/src/test/resources/SPARK-21101-1.0.jar
+sql/hive/src/test/resources/TestUDTF.jar
+sql/hive/src/test/resources/data/files/TestSerDe.jar
+sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.13.jar
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 518d936c3c85c..1a2da3b01726b 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -63,12 +63,19 @@ Other build examples can be found below.
To create a Spark distribution like those distributed by the
[Spark Downloads](https://spark.apache.org/downloads.html) page, and that is laid out so as
-to be runnable, use `./dev/make-distribution.sh` in the project root directory. It can be configured
-with Maven profile settings and so on like the direct Maven build. Example:
+to be runnable, use `./dev/make-distribution.sh` in the project root directory. By default,
+it uses Maven as the build tool, and it can be configured with Maven profile settings and so on
+like the direct Maven build. Example:
./dev/make-distribution.sh --name custom-spark --pip --r --tgz -Psparkr -Phive -Phive-thriftserver -Pyarn -Pkubernetes
-This will build Spark distribution along with Python pip and R packages. For more information on usage, run `./dev/make-distribution.sh --help`
+This will build Spark distribution along with Python pip and R packages.
+
+To switch to SBT (experimental), use `--sbt-enabled`. Example:
+
+ ./dev/make-distribution.sh --name custom-spark --pip --r --tgz --sbt-enabled -Psparkr -Phive -Phive-thriftserver -Pyarn -Pkubernetes
+
+For more information on usage, run `./dev/make-distribution.sh --help`
## Specifying the Hadoop Version and Enabling YARN
@@ -261,7 +268,7 @@ On Linux, this can be done by `sudo service docker start`.
or
- ./build/sbt docker-integration-tests/test
+ ./build/sbt -Pdocker-integration-tests docker-integration-tests/test
diff --git a/pom.xml b/pom.xml
--- a/pom.xml
+++ b/pom.xml
@@ -137,14 +137,15 @@
     <hadoop.version>3.4.1</hadoop.version>
-    <parquet.version>1.15.0</parquet.version>
-    <orc.version>2.1.0</orc.version>
+    <parquet.version>1.15.1</parquet.version>
+    <orc.version>2.1.1</orc.version>
     <orc.classifier>shaded-protobuf</orc.classifier>
-    <jetty.version>11.0.24</jetty.version>
+    <jetty.version>11.0.25</jetty.version>
+    <kryo.version>4.0.3</kryo.version>
-    <aws.java.sdk.v2.version>2.25.53</aws.java.sdk.v2.version>
+    <aws.java.sdk.v2.version>2.29.52</aws.java.sdk.v2.version>
@@ -168,7 +169,7 @@
-    <commons-collections4.version>4.4</commons-collections4.version>
+    <commons-collections4.version>4.5.0</commons-collections4.version>
@@ -187,7 +188,7 @@
-    <commons-io.version>2.18.0</commons-io.version>
+    <commons-io.version>2.19.0</commons-io.version>
@@ -198,7 +199,7 @@
-    <jersey.version>3.0.16</jersey.version>
+    <jersey.version>3.0.17</jersey.version>
@@ -213,17 +214,17 @@
-    <datasketches.version>6.1.1</datasketches.version>
-    <netty.version>4.1.118.Final</netty.version>
+    <datasketches.version>6.2.0</datasketches.version>
+    <netty.version>4.1.119.Final</netty.version>
-    <junit.version>5.11.4</junit.version>
-    <junit-platform.version>1.11.4</junit-platform.version>
+    <junit.version>5.12.2</junit.version>
+    <junit-platform.version>1.12.2</junit-platform.version>
-    <jupiter-interface.version>0.13.3</jupiter-interface.version>
+    <jupiter-interface.version>0.14.0</jupiter-interface.version>
+    <test.objc.disable.initialize.fork.safety>NO</test.objc.disable.initialize.fork.safety>
     <test.default.exclude.tags>org.apache.spark.tags.ChromeUITest</test.default.exclude.tags>
@@ -336,10 +344,11 @@
-    3.22.0
+    3.23.2
+    <maven-jar-plugin.version>3.4.2</maven-jar-plugin.version>
@@ -478,12 +487,42 @@
           <exclusion>
             <groupId>org.apache.xbean</groupId>
             <artifactId>xbean-asm7-shaded</artifactId>
           </exclusion>
+          <exclusion>
+            <groupId>com.esotericsoftware</groupId>
+            <artifactId>kryo-shaded</artifactId>
+          </exclusion>
         </exclusions>
       </dependency>
       <dependency>
         <groupId>com.twitter</groupId>
         <artifactId>chill-java</artifactId>
         <version>${chill.version}</version>
+        <exclusions>
+          <exclusion>
+            <groupId>com.esotericsoftware</groupId>
+            <artifactId>kryo-shaded</artifactId>
+          </exclusion>
+        </exclusions>
       </dependency>
+      <dependency>
+        <groupId>com.esotericsoftware</groupId>
+        <artifactId>kryo-shaded</artifactId>
+        <version>${kryo.version}</version>
+        <exclusions>
+          <exclusion>
+            <groupId>org.objenesis</groupId>
+            <artifactId>objenesis</artifactId>
+          </exclusion>
+        </exclusions>
+      </dependency>
+      <dependency>
+        <groupId>org.objenesis</groupId>
+        <artifactId>objenesis</artifactId>
+        <version>3.3</version>
+      </dependency>
       <dependency>
         <groupId>com.github.jnr</groupId>
@@ -497,7 +536,7 @@
         <groupId>org.apache.xbean</groupId>
         <artifactId>xbean-asm9-shaded</artifactId>
-        <version>4.26</version>
+        <version>4.27</version>
@@ -2811,6 +2850,7 @@
             <SPARK_TESTING>1</SPARK_TESTING>
             <JAVA_HOME>${test.java.home}</JAVA_HOME>
             <SPARK_BEELINE_OPTS>-DmyKey=yourValue</SPARK_BEELINE_OPTS>
+            <OBJC_DISABLE_INITIALIZE_FORK_SAFETY>${test.objc.disable.initialize.fork.safety}</OBJC_DISABLE_INITIALIZE_FORK_SAFETY>
             <log4j.configurationFile>file:src/test/resources/log4j2.properties</log4j.configurationFile>
@@ -2865,6 +2905,7 @@
             <SPARK_SCALA_VERSION>${scala.binary.version}</SPARK_SCALA_VERSION>
             <SPARK_TESTING>1</SPARK_TESTING>
             <JAVA_HOME>${test.java.home}</JAVA_HOME>
+            <OBJC_DISABLE_INITIALIZE_FORK_SAFETY>${test.objc.disable.initialize.fork.safety}</OBJC_DISABLE_INITIALIZE_FORK_SAFETY>
             <log4j.configurationFile>file:src/test/resources/log4j2.properties</log4j.configurationFile>
@@ -2898,7 +2939,7 @@
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-jar-plugin</artifactId>
-        <version>3.4.0</version>
+        <version>${maven-jar-plugin.version}</version>
@@ -2925,7 +2966,7 @@
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-clean-plugin</artifactId>
-        <version>3.4.0</version>
+        <version>3.4.1</version>
@@ -3025,12 +3066,12 @@
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-install-plugin</artifactId>
-        <version>3.1.2</version>
+        <version>3.1.4</version>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-deploy-plugin</artifactId>
-        <version>3.1.2</version>
+        <version>3.1.4</version>
@@ -3542,6 +3583,17 @@
+    <profile>
+      <id>macOS</id>
+      <properties>
+        <test.objc.disable.initialize.fork.safety>YES</test.objc.disable.initialize.fork.safety>
+      </properties>
+      <activation>
+        <os>
+          <family>mac</family>
+        </os>
+      </activation>
+    </profile>
     <profile>
       <id>jdwp-test-debug</id>
@@ -3588,7 +3640,7 @@
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-jar-plugin</artifactId>
-        <version>3.4.0</version>
+        <version>${maven-jar-plugin.version}</version>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 1a3936195f9bf..d89bb285ed8dc 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -35,6 +35,8 @@ import com.typesafe.tools.mima.core.*
object MimaExcludes {
lazy val v41excludes = v40excludes ++ Seq(
+ // [SPARK-51261][ML][CONNECT] Introduce model size estimation to control ml cache
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.getSizeInBytes")
)
// Exclude rules for 4.0.x from 3.5.0
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 85c5474205d37..d282fa2611c2f 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -412,6 +412,8 @@ object SparkBuild extends PomBuild {
/* Hive console settings */
enable(Hive.settings)(hive)
+ enable(HiveThriftServer.settings)(hiveThriftServer)
+
enable(SparkConnectCommon.settings)(connectCommon)
enable(SparkConnect.settings)(connect)
enable(SparkConnectClient.settings)(connectClient)
@@ -838,6 +840,8 @@ object SparkConnectClient {
testOnly := ((Test / testOnly) dependsOn (buildTestDeps)).evaluated,
+ (Test / javaOptions) += "-Darrow.memory.debug.allocator=true",
+
(assembly / test) := { },
(assembly / logLevel) := Level.Info,
@@ -1003,7 +1007,7 @@ object KubernetesIntegrationTests {
rDockerFile = ""
}
val extraOptions = if (javaImageTag.isDefined) {
- Seq("-b", s"java_image_tag=$javaImageTag")
+ Seq("-b", s"java_image_tag=${javaImageTag.get}")
} else {
Seq("-f", s"$dockerFile")
}
@@ -1203,6 +1207,14 @@ object Hive {
)
}
+object HiveThriftServer {
+ lazy val settings = Seq(
+ excludeDependencies ++= Seq(
+ ExclusionRule("org.apache.hive", "hive-llap-common"),
+ ExclusionRule("org.apache.hive", "hive-llap-client"))
+ )
+}
+
object YARN {
val genConfigProperties = TaskKey[Unit]("gen-config-properties",
"Generate config.properties which contains a setting whether Hadoop is provided or not")
@@ -1578,15 +1590,24 @@ object TestSettings {
fork := true,
// Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes
// launched by the tests have access to the correct test-time classpath.
- (Test / envVars) ++= Map(
- "SPARK_DIST_CLASSPATH" ->
- (Test / fullClasspath).value.files.map(_.getAbsolutePath)
- .mkString(File.pathSeparator).stripSuffix(File.pathSeparator),
- "SPARK_PREPEND_CLASSES" -> "1",
- "SPARK_SCALA_VERSION" -> scalaBinaryVersion.value,
- "SPARK_TESTING" -> "1",
- "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home")),
- "SPARK_BEELINE_OPTS" -> "-DmyKey=yourValue"),
+ (Test / envVars) ++= {
+ val baseEnvVars = Map(
+ "SPARK_DIST_CLASSPATH" ->
+ (Test / fullClasspath).value.files.map(_.getAbsolutePath)
+ .mkString(File.pathSeparator).stripSuffix(File.pathSeparator),
+ "SPARK_PREPEND_CLASSES" -> "1",
+ "SPARK_SCALA_VERSION" -> scalaBinaryVersion.value,
+ "SPARK_TESTING" -> "1",
+ "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home")),
+ "SPARK_BEELINE_OPTS" -> "-DmyKey=yourValue"
+ )
+
+ if (sys.props("os.name").contains("Mac OS X")) {
+ baseEnvVars + ("OBJC_DISABLE_INITIALIZE_FORK_SAFETY" -> "YES")
+ } else {
+ baseEnvVars
+ }
+ },
// Copy system properties to forked JVMs so that tests know proxy settings
(Test / javaOptions) ++= {
@@ -1694,7 +1715,7 @@ object TestSettings {
(Test / testOptions) += Tests.Argument(TestFrameworks.ScalaTest, "-W", "120", "300"),
(Test / testOptions) += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"),
// Enable Junit testing.
- libraryDependencies += "com.github.sbt.junit" % "jupiter-interface" % "0.13.3" % "test",
+ libraryDependencies += "com.github.sbt.junit" % "jupiter-interface" % "0.14.0" % "test",
// `parallelExecutionInTest` controls whether test suites belonging to the same SBT project
// can run in parallel with one another. It does NOT control whether tests execute in parallel
// within the same JVM (which is controlled by `testForkedParallel`) or whether test cases
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 2885dd2fc5fb7..579020e5af0ce 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -33,14 +33,14 @@ addSbtPlugin("com.github.sbt" % "sbt-unidoc" % "0.5.0")
addSbtPlugin("io.spray" % "sbt-revolver" % "0.10.0")
-libraryDependencies += "org.ow2.asm" % "asm" % "9.7.1"
+libraryDependencies += "org.ow2.asm" % "asm" % "9.8"
-libraryDependencies += "org.ow2.asm" % "asm-commons" % "9.7.1"
+libraryDependencies += "org.ow2.asm" % "asm-commons" % "9.8"
addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.8.3")
addSbtPlugin("com.github.sbt" % "sbt-pom-reader" % "2.4.0")
-addSbtPlugin("com.github.sbt.junit" % "sbt-jupiter-interface" % "0.13.3")
+addSbtPlugin("com.github.sbt.junit" % "sbt-jupiter-interface" % "0.14.0")
addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.7")
diff --git a/python/docs/source/index.rst b/python/docs/source/index.rst
index 72a846290fe9e..2e102c8de71e7 100644
--- a/python/docs/source/index.rst
+++ b/python/docs/source/index.rst
@@ -36,6 +36,18 @@ to enable processing and analysis of data at any size for everyone familiar with
PySpark supports all of Spark's features such as Spark SQL,
DataFrames, Structured Streaming, Machine Learning (MLlib) and Spark Core.
+.. list-table::
+ :widths: 10 80 10
+ :header-rows: 0
+ :class: borderless spec_table
+
+ * -
+ - .. image:: ../../../docs/img/pyspark-python_spark_connect_client.png
+ :target: getting_started/quickstart_connect.html
+ :width: 100%
+ :alt: Python Spark Connect Client
+ -
+
.. list-table::
:widths: 10 20 20 20 20 10
:header-rows: 0
@@ -72,6 +84,19 @@ DataFrames, Structured Streaming, Machine Learning (MLlib) and Spark Core.
:alt: Spark Core and RDDs
-
+.. _Index Page - Python Spark Connect Client:
+
+**Python Spark Connect Client**
+
+Spark Connect is a client-server architecture within Apache Spark that
+enables remote connectivity to Spark clusters from any application.
+PySpark provides the client for the Spark Connect server, allowing
+Spark to be used as a service.
+
+- :ref:`/getting_started/quickstart_connect.ipynb`
+- |binder_connect|_
+- `Spark Connect Overview `_
+
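A minimal, illustrative sketch of using the client described above (the ``sc://localhost`` address and the query are placeholders, not part of this change):

.. code-block:: python

    # Illustrative only: assumes a Spark Connect server is reachable at sc://localhost.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    spark.range(5).show()  # the query runs on the remote Spark Connect server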
.. _Index Page - Spark SQL and DataFrames:
**Spark SQL and DataFrames**
diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst
index 976aef7cb68b6..3f2b47c2859e4 100644
--- a/python/docs/source/migration_guide/pyspark_upgrade.rst
+++ b/python/docs/source/migration_guide/pyspark_upgrade.rst
@@ -19,6 +19,11 @@
Upgrading PySpark
==================
+Upgrading from PySpark 4.0 to 4.1
+---------------------------------
+
+* In Spark 4.1, Arrow-optimized Python UDFs support UDT input and output instead of falling back to regular Python UDFs. To restore the legacy behavior, set ``spark.sql.execution.pythonUDF.arrow.legacy.fallbackOnUDT`` to ``true`` (see the sketch below).
+
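A brief sketch of restoring the legacy behavior (assumes an existing ``spark`` session; only the configuration key comes from this change):

.. code-block:: python

    # Sketch: fall back to regular (pickled) Python UDFs for UDT input/output, as before 4.1.
    spark.conf.set("spark.sql.execution.pythonUDF.arrow.legacy.fallbackOnUDT", "true")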
Upgrading from PySpark 3.5 to 4.0
---------------------------------
@@ -75,8 +80,6 @@ Upgrading from PySpark 3.5 to 4.0
* In Spark 4.0, ``compute.ops_on_diff_frames`` is on by default. To restore the previous behavior, set ``compute.ops_on_diff_frames`` to ``false``.
* In Spark 4.0, the data type ``YearMonthIntervalType`` in ``DataFrame.collect`` no longer returns the underlying integers. To restore the previous behavior, set ``PYSPARK_YM_INTERVAL_LEGACY`` environment variable to ``1``.
* In Spark 4.0, items other than functions (e.g. ``DataFrame``, ``Column``, ``StructType``) have been removed from the wildcard import ``from pyspark.sql.functions import *``, you should import these items from proper modules (e.g. ``from pyspark.sql import DataFrame, Column``, ``from pyspark.sql.types import StructType``).
-* In Spark 4.0, ``spark.sql.execution.pythonUDF.arrow.enabled`` is enabled by default. If users have PyArrow and pandas installed in their local and Spark Cluster, it automatically optimizes the regular Python UDFs with Arrow. To turn off the Arrow optimization, set ``spark.sql.execution.pythonUDF.arrow.enabled`` to ``false``.
-* In Spark 4.0, ``spark.sql.execution.arrow.pyspark.enabled`` is enabled by default. If users have PyArrow and pandas installed in their local and Spark Cluster, it automatically makes use of Apache Arrow for columnar data transfers in PySpark. This optimization applies to ``pyspark.sql.DataFrame.toPandas`` and ``pyspark.sql.SparkSession.createDataFrame`` when its input is a Pandas DataFrame or a NumPy ndarray. To turn off the Arrow optimization, set ``spark.sql.execution.arrow.pyspark.enabled`` to ``false``.
Upgrading from PySpark 3.3 to 3.4
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index 9d45b2bc2be37..a3a2e11daf2e1 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -132,6 +132,7 @@ Mathematical Functions
radians
rand
randn
+ random
rint
round
sec
@@ -164,6 +165,7 @@ String Functions
char
char_length
character_length
+ chr
collate
collation
concat_ws
@@ -192,6 +194,7 @@ String Functions
overlay
position
printf
+ quote
randstr
regexp_count
regexp_extract
@@ -631,6 +634,7 @@ Misc Functions
try_reflect
typeof
user
+ uuid
version
diff --git a/python/docs/source/user_guide/pandas_on_spark/best_practices.rst b/python/docs/source/user_guide/pandas_on_spark/best_practices.rst
index 14c04aa622ecf..b819c50bef7f0 100644
--- a/python/docs/source/user_guide/pandas_on_spark/best_practices.rst
+++ b/python/docs/source/user_guide/pandas_on_spark/best_practices.rst
@@ -242,6 +242,80 @@ to handle large data in production, make it distributed by configuring the defau
See `Default Index Type `_ for more details about configuring default index.
+Handling index misalignment with ``distributed-sequence``
+----------------------------------------------------------
+
+While ``distributed-sequence`` ensures a globally sequential index, it does **not** guarantee that the same row-to-index mapping is maintained across different operations.
+Operations such as ``apply()``, ``groupby()``, or ``transform()`` may cause the index to be regenerated, leading to misalignment between rows and computed values.
+
+Issue example with ``apply()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In the following example, we load a dataset where ``record_id`` acts as a unique identifier, and we compute the duration (number of business days) using an ``apply()`` function.
+However, due to ``distributed-sequence`` index regeneration during ``apply()``, the results may be assigned to incorrect rows.
+
+.. code-block:: python
+
+ import pyspark.pandas as ps
+ import numpy as np
+
+ ps.set_option('compute.default_index_type', 'distributed-sequence')
+
+ df = ps.DataFrame({
+ 'record_id': ["RECORD_1001", "RECORD_1002"],
+ 'start_date': ps.to_datetime(["2024-01-01", "2024-01-02"]),
+ 'end_date': ps.to_datetime(["2024-01-01", "2024-01-03"])
+ })
+
+ df['duration'] = df.apply(lambda x: np.busday_count(x['start_date'].date(), x['end_date'].date()), axis=1)
+
+Expected output:
+
+.. code-block::
+
+ record_id start_date end_date duration
+ RECORD_1001 2024-01-01 2024-01-01 0
+ RECORD_1002 2024-01-02 2024-01-03 1
+
+However, due to the ``distributed-sequence`` index being re-generated during ``apply()``, the resulting DataFrame might look like this:
+
+.. code-block::
+
+ record_id start_date end_date duration
+ RECORD_1002 2024-01-02 2024-01-03 0 # Wrong mapping!
+ RECORD_1001 2024-01-01 2024-01-01 1 # Wrong mapping!
+
+Best practices to prevent index misalignment
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To ensure the row-to-index mapping remains consistent, consider the following approaches:
+
+1. **Explicitly set an index column before applying functions:**
+
+ .. code-block:: python
+
+ df = df.set_index("record_id") # Ensure the index is explicitly set
+ df['duration'] = df.apply(lambda x: np.busday_count(x['start_date'].date(), x['end_date'].date()), axis=1)
+
+2. **Persist the DataFrame before applying functions to maintain row ordering:**
+
+ .. code-block:: python
+
+ df = df.spark.persist()
+ df['duration'] = df.apply(lambda x: np.busday_count(x['start_date'].date(), x['end_date'].date()), axis=1)
+
+3. **Use the sequence index type instead (be aware of potential performance trade-offs):**
+
+ .. code-block:: python
+
+ ps.set_option('compute.default_index_type', 'sequence')
+
+If your application requires strict row-to-index mapping, consider using one of the above approaches rather than relying on the default ``distributed-sequence`` index.
+
+For more information, refer to `Default Index Type `_
+
+
+
Reduce the operations on different DataFrame/Series
---------------------------------------------------
diff --git a/python/docs/source/user_guide/pandas_on_spark/options.rst b/python/docs/source/user_guide/pandas_on_spark/options.rst
index e8fffea7e33be..14164b771e3f2 100644
--- a/python/docs/source/user_guide/pandas_on_spark/options.rst
+++ b/python/docs/source/user_guide/pandas_on_spark/options.rst
@@ -208,6 +208,19 @@ This is conceptually equivalent to the PySpark example as below:
>>> spark_df.rdd.zipWithIndex().map(lambda p: p[1]).collect()
[0, 1, 2]
+.. warning::
+ Unlike ``sequence``, ``distributed-sequence`` is executed in a distributed environment,
+ so the rows corresponding to each index may vary even though the index itself remains globally sequential.
+
+ This happens because the rows are distributed across multiple partitions and nodes,
+ leading to indeterministic row-to-index mappings when the data is loaded.
+
+ Additionally, when using operations such as ``apply()``, ``groupby()``, or ``transform()``,
+ a new ``distributed-sequence`` index may be generated, which does not necessarily match the original index of the DataFrame.
+ This can result in misaligned row-to-index mappings, leading to incorrect calculations.
+
+ To avoid this issue, see `Handling index misalignment with distributed-sequence `_
+
**distributed**: It implements a monotonically increasing sequence simply by using
PySpark's `monotonically_increasing_id` function in a fully distributed manner. The
values are indeterministic. If the index does not have to be a sequence that increases
diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst b/python/docs/source/user_guide/sql/arrow_pandas.rst
index b9e389f8fe7dd..ffff59f136cbe 100644
--- a/python/docs/source/user_guide/sql/arrow_pandas.rst
+++ b/python/docs/source/user_guide/sql/arrow_pandas.rst
@@ -356,8 +356,8 @@ Arrow Python UDFs are user defined functions that are executed row-by-row, utili
transfer and serialization. To define an Arrow Python UDF, you can use the :meth:`udf` decorator or wrap the function
with the :meth:`udf` method, ensuring the ``useArrow`` parameter is set to True. Additionally, you can enable Arrow
optimization for Python UDFs throughout the entire SparkSession by setting the Spark configuration
-``spark.sql.execution.pythonUDF.arrow.enabled`` to true, which is the default. It's important to note that the Spark
-configuration takes effect only when ``useArrow`` is either not set or set to None.
+``spark.sql.execution.pythonUDF.arrow.enabled`` to true. It's important to note that the Spark configuration takes
+effect only when ``useArrow`` is either not set or set to None.
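As a hedged sketch of the per-UDF opt-in described above (the function, return type, and query are illustrative and assume an existing ``spark`` session):

.. code-block:: python

    from pyspark.sql.functions import udf

    # useArrow=True opts this UDF into Arrow-based row transfer regardless of the
    # session-wide spark.sql.execution.pythonUDF.arrow.enabled setting.
    @udf(returnType="int", useArrow=True)
    def add_one(x):
        return x + 1

    spark.range(3).select(add_one("id")).show()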
The type hints for Arrow Python UDFs should be specified in the same way as for default, pickled Python UDFs.
@@ -434,7 +434,7 @@ working with timestamps in ``pandas_udf``\s to get the best performance, see
Recommended Pandas and PyArrow Versions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-For usage with pyspark.sql, the minimum supported versions of Pandas is 2.0.0 and PyArrow is 10.0.0.
+For usage with pyspark.sql, the minimum supported versions of Pandas is 2.0.0 and PyArrow is 11.0.0.
Higher versions may be used, however, compatibility and data correctness can not be guaranteed and should
be verified by the user.
diff --git a/python/docs/source/user_guide/sql/type_conversions.rst b/python/docs/source/user_guide/sql/type_conversions.rst
index 80f8aa83db7eb..2f13701995ef2 100644
--- a/python/docs/source/user_guide/sql/type_conversions.rst
+++ b/python/docs/source/user_guide/sql/type_conversions.rst
@@ -57,7 +57,7 @@ are listed below:
- Default
* - spark.sql.execution.pythonUDF.arrow.enabled
- Enable PyArrow in PySpark. See more `here `_.
- - True
+ - False
* - spark.sql.pyspark.inferNestedDictAsStruct.enabled
- When enabled, nested dictionaries are inferred as StructType. Otherwise, they are inferred as MapType.
- False
diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py
index 8cc2eb182f8f3..da4d25cc908c0 100755
--- a/python/packaging/classic/setup.py
+++ b/python/packaging/classic/setup.py
@@ -291,6 +291,7 @@ def run(self):
"pyspark.sql.plot",
"pyspark.sql.protobuf",
"pyspark.sql.streaming",
+ "pyspark.sql.streaming.proto",
"pyspark.sql.worker",
"pyspark.streaming",
"pyspark.bin",
diff --git a/python/packaging/client/setup.py b/python/packaging/client/setup.py
index fd4beba29b76d..30392bcada4cb 100755
--- a/python/packaging/client/setup.py
+++ b/python/packaging/client/setup.py
@@ -68,6 +68,7 @@
test_packages = []
if "SPARK_TESTING" in os.environ:
test_packages = [
+ "pyspark.errors.tests.connect",
"pyspark.tests", # for Memory profiler parity tests
"pyspark.resource.tests",
"pyspark.sql.tests",
@@ -79,6 +80,7 @@
"pyspark.sql.tests.connect.pandas",
"pyspark.sql.tests.connect.shell",
"pyspark.sql.tests.pandas",
+ "pyspark.sql.tests.pandas.helper",
"pyspark.sql.tests.plot",
"pyspark.sql.tests.streaming",
"pyspark.ml.tests",
@@ -168,6 +170,7 @@
"pyspark.sql.plot",
"pyspark.sql.protobuf",
"pyspark.sql.streaming",
+ "pyspark.sql.streaming.proto",
"pyspark.sql.worker",
"pyspark.streaming",
"pyspark.pandas",
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 93a64d8eef10a..59f7856688ee9 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -15,12 +15,13 @@
# limitations under the License.
#
+import os
import sys
import select
import struct
-import socketserver as SocketServer
+import socketserver
import threading
-from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union
+from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union, Optional
from pyspark.serializers import read_int, CPickleSerializer
from pyspark.errors import PySparkRuntimeError
@@ -252,7 +253,7 @@ def addInPlace(self, value1: U, value2: U) -> U:
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) # type: ignore[type-var]
-class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+class UpdateRequestHandler(socketserver.StreamRequestHandler):
"""
This handler will keep polling updates from the same socket until the
@@ -293,37 +294,64 @@ def authenticate_and_accum_updates() -> bool:
"The value of the provided token to the AccumulatorServer is not correct."
)
- # first we keep polling till we've received the authentication token
- poll(authenticate_and_accum_updates)
+ # Unix Domain Socket does not need the auth.
+ if auth_token is not None:
+ # first we keep polling till we've received the authentication token
+ poll(authenticate_and_accum_updates)
+
# now we've authenticated, don't need to check for the token anymore
poll(accum_updates)
-class AccumulatorServer(SocketServer.TCPServer):
+class AccumulatorTCPServer(socketserver.TCPServer):
+ server_shutdown = False
+
def __init__(
self,
server_address: Tuple[str, int],
RequestHandlerClass: Type["socketserver.BaseRequestHandler"],
auth_token: str,
):
- SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
+ super().__init__(server_address, RequestHandlerClass)
self.auth_token = auth_token
- """
- A simple TCP server that intercepts shutdown() in order to interrupt
- our continuous polling on the handler.
- """
+ def shutdown(self) -> None:
+ self.server_shutdown = True
+ super().shutdown()
+ self.server_close()
+
+
+class AccumulatorUnixServer(socketserver.UnixStreamServer):
server_shutdown = False
+ def __init__(
+ self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler]
+ ):
+ super().__init__(socket_path, RequestHandlerClass)
+ self.auth_token = None
+
def shutdown(self) -> None:
self.server_shutdown = True
- SocketServer.TCPServer.shutdown(self)
+ super().shutdown()
self.server_close()
+ if os.path.exists(self.server_address): # type: ignore[arg-type]
+ os.remove(self.server_address) # type: ignore[arg-type]
+
+
+def _start_update_server(
+ auth_token: str, is_unix_domain_sock: bool, socket_path: Optional[str] = None
+) -> Union[AccumulatorTCPServer, AccumulatorUnixServer]:
+ """Start a TCP or Unix Domain Socket server for accumulator updates."""
+ if is_unix_domain_sock:
+ assert socket_path is not None
+ if os.path.exists(socket_path):
+ os.remove(socket_path)
+ server = AccumulatorUnixServer(socket_path, UpdateRequestHandler)
+ else:
+ server = AccumulatorTCPServer(
+ ("localhost", 0), UpdateRequestHandler, auth_token
+ ) # type: ignore[assignment]
-
-def _start_update_server(auth_token: str) -> AccumulatorServer:
- """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
- server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
thread.start()
diff --git a/python/pyspark/core/broadcast.py b/python/pyspark/core/broadcast.py
index 69d57c35614d3..2d5658284be88 100644
--- a/python/pyspark/core/broadcast.py
+++ b/python/pyspark/core/broadcast.py
@@ -125,8 +125,8 @@ def __init__( # type: ignore[misc]
if sc._encryption_enabled:
# with encryption, we ask the jvm to do the encryption for us, we send it data
# over a socket
- port, auth_secret = self._python_broadcast.setupEncryptionServer()
- (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
+ conn_info, auth_secret = self._python_broadcast.setupEncryptionServer()
+ (encryption_sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
broadcast_out = ChunkedStream(encryption_sock_file, 8192)
else:
# no encryption, we can just write pickled data directly to the file from python
@@ -270,8 +270,8 @@ def value(self) -> T:
# we only need to decrypt it here when encryption is enabled and
# if its on the driver, since executor decryption is handled already
if self._sc is not None and self._sc._encryption_enabled:
- port, auth_secret = self._python_broadcast.setupDecryptionServer()
- (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret)
+ conn_info, auth_secret = self._python_broadcast.setupDecryptionServer()
+ (decrypted_sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
self._python_broadcast.waitTillBroadcastDataSent()
return self.load(decrypted_sock_file)
else:
diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py
index 5fcd4ffb09210..f4d3bbcf8f5b9 100644
--- a/python/pyspark/core/context.py
+++ b/python/pyspark/core/context.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import uuid
import os
import shutil
import signal
@@ -305,11 +306,29 @@ def _do_init(
# they will be passed back to us through a TCP server
assert self._gateway is not None
auth_token = self._gateway.gateway_parameters.auth_token
+ is_unix_domain_sock = (
+ self._conf.get(
+ "spark.python.unix.domain.socket.enabled",
+ os.environ.get("PYSPARK_UDS_MODE", "false"),
+ ).lower()
+ == "true"
+ )
+ socket_path = None
+ if is_unix_domain_sock:
+ socket_dir = self._conf.get("spark.python.unix.domain.socket.dir")
+ if socket_dir is None:
+ socket_dir = getattr(self._jvm, "java.lang.System").getProperty("java.io.tmpdir")
+ socket_path = os.path.join(socket_dir, f".{uuid.uuid4()}.sock")
start_update_server = accumulators._start_update_server
- self._accumulatorServer = start_update_server(auth_token)
- (host, port) = self._accumulatorServer.server_address
+ self._accumulatorServer = start_update_server(auth_token, is_unix_domain_sock, socket_path)
assert self._jvm is not None
- self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
+ if is_unix_domain_sock:
+ self._javaAccumulator = self._jvm.PythonAccumulatorV2(
+ self._accumulatorServer.server_address
+ )
+ else:
+ (host, port) = self._accumulatorServer.server_address # type: ignore[misc]
+ self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
self._jsc.sc().register(self._javaAccumulator)
# If encryption is enabled, we need to setup a server in the jvm to read broadcast
@@ -880,7 +899,7 @@ def _serialize_to_jvm(
if self._encryption_enabled:
# with encryption, we open a server in java and send the data directly
server = server_func()
- (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
+ (sock_file, _) = local_connect_and_auth(server.connInfo(), server.secret())
chunked_out = ChunkedStream(sock_file, 8192)
serializer.dump_stream(data, chunked_out)
chunked_out.close()
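An aside, not part of the patch: a rough sketch of how the Unix domain socket mode read above could be switched on from user code. Only the configuration key ``spark.python.unix.domain.socket.enabled`` is taken from this diff; whether it is intended as a user-facing setting is an assumption here.

.. code-block:: python

    from pyspark import SparkConf, SparkContext

    # Sketch: ask the Python accumulator server to listen on a Unix domain socket
    # instead of a localhost TCP port (local mode used purely for illustration).
    conf = (
        SparkConf()
        .setMaster("local[2]")
        .setAppName("uds-sketch")
        .set("spark.python.unix.domain.socket.enabled", "true")
    )
    sc = SparkContext(conf=conf)
    print(sc.parallelize(range(10)).sum())
    sc.stop()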
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index a23af109ea6de..ca33ce2c39ef7 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import uuid
import numbers
import os
import signal
@@ -93,8 +93,20 @@ def manager():
# Create a new process group to corral our children
os.setpgid(0, 0)
+ is_unix_domain_sock = os.environ.get("PYTHON_UNIX_DOMAIN_ENABLED", "false").lower() == "true"
+ socket_path = None
+
# Create a listening socket on the loopback interface
- if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
+ if is_unix_domain_sock:
+ assert "PYTHON_WORKER_FACTORY_SOCK_DIR" in os.environ
+ socket_path = os.path.join(
+ os.environ["PYTHON_WORKER_FACTORY_SOCK_DIR"], f".{uuid.uuid4()}.sock"
+ )
+ listen_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ listen_sock.bind(socket_path)
+ listen_sock.listen(max(1024, SOMAXCONN))
+ listen_port = socket_path
+ elif os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
listen_sock = socket.socket(AF_INET6, SOCK_STREAM)
listen_sock.bind(("::1", 0, 0, 0))
listen_sock.listen(max(1024, SOMAXCONN))
@@ -108,10 +120,15 @@ def manager():
# re-open stdin/stdout in 'wb' mode
stdin_bin = os.fdopen(sys.stdin.fileno(), "rb", 4)
stdout_bin = os.fdopen(sys.stdout.fileno(), "wb", 4)
- write_int(listen_port, stdout_bin)
+ if is_unix_domain_sock:
+ write_with_length(listen_port.encode("utf-8"), stdout_bin)
+ else:
+ write_int(listen_port, stdout_bin)
stdout_bin.flush()
def shutdown(code):
+ if socket_path is not None and os.path.exists(socket_path):
+ os.remove(socket_path)
signal.signal(SIGTERM, SIG_DFL)
# Send SIGHUP to notify workers of shutdown
os.kill(0, SIGHUP)
@@ -195,7 +212,10 @@ def handle_sigterm(*args):
write_int(os.getpid(), outfile)
outfile.flush()
outfile.close()
- authenticated = False
+ authenticated = (
+ os.environ.get("PYTHON_UNIX_DOMAIN_ENABLED", "false").lower() == "true"
+ or False
+ )
while True:
code = worker(sock, authenticated)
if code == 0:
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 33730f757dbd9..4e2727a87585d 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -159,8 +159,8 @@
"Calling property or member '' is not supported in PySpark Classic, please use Spark Connect instead."
]
},
- "COLLATION_INVALID_PROVIDER" : {
- "message" : [
+ "COLLATION_INVALID_PROVIDER": {
+ "message": [
"The value does not represent a correct collation provider. Supported providers are: []."
]
},
@@ -189,11 +189,21 @@
"Remote client cannot create a SparkContext. Create SparkSession instead."
]
},
+ "DATA_SOURCE_EXTRANEOUS_FILTERS": {
+ "message": [
+ ".pushFilters() returned filters that are not part of the input. Make sure that each returned filter is one of the input filters by reference."
+ ]
+ },
"DATA_SOURCE_INVALID_RETURN_TYPE": {
"message": [
"Unsupported return type ('') from Python data source ''. Expected types: ."
]
},
+ "DATA_SOURCE_PUSHDOWN_DISABLED": {
+ "message": [
+ " implements pushFilters() but filter pushdown is disabled because configuration '' is false. Set it to true to enable filter pushdown."
+ ]
+ },
"DATA_SOURCE_RETURN_SCHEMA_MISMATCH": {
"message": [
"Return schema mismatch in the result from 'read' method. Expected: columns, Found: columns. Make sure the returned values match the required output schema."
@@ -204,6 +214,11 @@
"Expected , but got ."
]
},
+ "DATA_SOURCE_UNSUPPORTED_FILTER": {
+ "message": [
+ "Unexpected filter ."
+ ]
+ },
"DIFFERENT_PANDAS_DATAFRAME": {
"message": [
"DataFrames are not almost equal:",
@@ -372,8 +387,8 @@
"All items in `` should be in , got ."
]
},
- "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS" : {
- "message" : [
+ "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS": {
+ "message": [
"Collations can only be applied to string types, but the JSON data type is ."
]
},
@@ -502,8 +517,8 @@
" and should be of the same length, got and ."
]
},
- "MALFORMED_VARIANT" : {
- "message" : [
+ "MALFORMED_VARIANT": {
+ "message": [
"Variant binary is malformed. Please check the data source is valid."
]
},
@@ -517,7 +532,7 @@
"A master URL must be set in your configuration."
]
},
- "MEMORY_PROFILE_INVALID_SOURCE":{
+ "MEMORY_PROFILE_INVALID_SOURCE": {
"message": [
"Memory profiler can only be used on editors with line numbers."
]
@@ -812,7 +827,7 @@
" >= must be installed; however, it was not found."
]
},
- "PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS" : {
+ "PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS": {
"message": [
"The Pandas SCALAR_ITER UDF outputs more rows than input rows."
]
@@ -865,7 +880,7 @@
},
"RESPONSE_ALREADY_RECEIVED": {
"message": [
- "OPERATION_NOT_FOUND on the server but responses were already received from it."
+ " on the server but responses were already received from it."
]
},
"RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDF": {
@@ -905,7 +920,7 @@
},
"SCHEMA_MISMATCH_FOR_PANDAS_UDF": {
"message": [
- "Result vector from pandas_udf was not the required length: expected , got ."
+ "Result vector from was not the required length: expected , got ."
]
},
"SESSION_ALREADY_EXIST": {
diff --git a/python/pyspark/errors/exceptions/__init__.py b/python/pyspark/errors/exceptions/__init__.py
index c66f35958f8dd..2590731648534 100644
--- a/python/pyspark/errors/exceptions/__init__.py
+++ b/python/pyspark/errors/exceptions/__init__.py
@@ -30,3 +30,4 @@ def _write_self() -> None:
sort_keys=True,
indent=2,
)
+ f.write("\n")
diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py
index 13501ba0de785..347f8db318476 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -17,8 +17,9 @@
import warnings
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Dict, Optional, cast, Iterable, TYPE_CHECKING, List
+from typing import Dict, Optional, TypeVar, cast, Iterable, TYPE_CHECKING, List
+from pyspark.errors.exceptions.tblib import Traceback
from pyspark.errors.utils import ErrorClassesReader
from pyspark.logger import PySparkLogger
from pickle import PicklingError
@@ -27,6 +28,9 @@
from pyspark.sql.types import Row
+T = TypeVar("T", bound="PySparkException")
+
+
class PySparkException(Exception):
"""
Base Exception for handling errors generated from PySpark.
@@ -449,3 +453,30 @@ def summary(self) -> str:
Summary of the exception cause.
"""
...
+
+
+def recover_python_exception(e: T) -> T:
+ """
+ Recover Python exception stack trace.
+
+ Many JVM exception types may wrap Python exceptions. For example:
+ - UDFs can cause PythonException
+ - UDTFs and Data Sources can cause AnalysisException
+ """
+ python_exception_header = "Traceback (most recent call last):"
+ try:
+ message = str(e)
+ start = message.find(python_exception_header)
+ if start == -1:
+ # No Python exception found
+ return e
+
+ # The message contains a Python exception. Parse it to use it as the exception's traceback.
+ # This allows richer error messages, for example showing line content in Python UDF.
+ python_exception_string = message[start:]
+ tb = Traceback.from_string(python_exception_string)
+ tb.populate_linecache()
+ return e.with_traceback(tb.as_traceback())
+ except BaseException:
+ # Parsing the stacktrace is best effort.
+ return e
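To make the recovery step above concrete, here is a small, self-contained sketch (not part of the patch) of how a Python traceback embedded in an error message is parsed and re-attached; the file name and message are made up for illustration:

.. code-block:: python

    import traceback
    from pyspark.errors.exceptions.tblib import Traceback

    # A message as it might arrive from the JVM, wrapping a Python worker traceback.
    message = (
        "Job aborted due to a Python worker failure.\n"
        "Traceback (most recent call last):\n"
        '  File "example_udf.py", line 3, in foo\n'
        "    raise ValueError('bar')\n"
        "ValueError: bar\n"
    )

    start = message.find("Traceback (most recent call last):")
    tb = Traceback.from_string(message[start:])
    tb.populate_linecache()  # makes the quoted source line renderable even without the file
    print("".join(traceback.format_tb(tb.as_traceback())))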
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 8ae4d48541c35..ba5c2601a6a6e 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -37,6 +37,7 @@
UnknownException as BaseUnknownException,
QueryContext as BaseQueryContext,
QueryContextType,
+ recover_python_exception,
)
if TYPE_CHECKING:
@@ -185,6 +186,11 @@ def getQueryContext(self) -> List[BaseQueryContext]:
def convert_exception(e: "Py4JJavaError") -> CapturedException:
+ converted = _convert_exception(e)
+ return recover_python_exception(converted)
+
+
+def _convert_exception(e: "Py4JJavaError") -> CapturedException:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index f87494b72426a..fafdd6b84297f 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import pyspark.sql.connect.proto as pb2
+import grpc
import json
+from grpc import StatusCode
from typing import Dict, List, Optional, TYPE_CHECKING
from pyspark.errors.exceptions.base import (
@@ -39,9 +40,11 @@
StreamingPythonRunnerInitializationException as BaseStreamingPythonRunnerInitException,
PickleException as BasePickleException,
UnknownException as BaseUnknownException,
+ recover_python_exception,
)
if TYPE_CHECKING:
+ import pyspark.sql.connect.proto as pb2
from google.rpc.error_details_pb2 import ErrorInfo
@@ -54,9 +57,25 @@ class SparkConnectException(PySparkException):
def convert_exception(
info: "ErrorInfo",
truncated_message: str,
- resp: Optional[pb2.FetchErrorDetailsResponse],
+ resp: Optional["pb2.FetchErrorDetailsResponse"],
display_server_stacktrace: bool = False,
+ grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
) -> SparkConnectException:
+ converted = _convert_exception(
+ info, truncated_message, resp, display_server_stacktrace, grpc_status_code
+ )
+ return recover_python_exception(converted)
+
+
+def _convert_exception(
+ info: "ErrorInfo",
+ truncated_message: str,
+ resp: Optional["pb2.FetchErrorDetailsResponse"],
+ display_server_stacktrace: bool = False,
+ grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
+) -> SparkConnectException:
+ import pyspark.sql.connect.proto as pb2
+
raw_classes = info.metadata.get("classes")
classes: List[str] = json.loads(raw_classes) if raw_classes else []
sql_state = info.metadata.get("sqlState")
@@ -89,8 +108,9 @@ def convert_exception(
if "org.apache.spark.api.python.PythonException" in classes:
return PythonException(
- "\n An exception was thrown from the Python worker. "
- "Please see the stack trace below.\n%s" % message
+ message="\n An exception was thrown from the Python worker. "
+ "Please see the stack trace below.\n%s" % message,
+ grpc_status_code=grpc_status_code,
)
# Return exception based on class mapping
@@ -113,6 +133,7 @@ def convert_exception(
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
contexts=contexts,
+ grpc_status_code=grpc_status_code,
)
# Return UnknownException if there is no matched exception class
@@ -125,16 +146,17 @@ def convert_exception(
server_stacktrace=stacktrace,
display_server_stacktrace=display_server_stacktrace,
contexts=contexts,
+ grpc_status_code=grpc_status_code,
)
-def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str:
+def _extract_jvm_stacktrace(resp: "pb2.FetchErrorDetailsResponse") -> str:
if len(resp.errors[resp.root_error_idx].stack_trace) == 0:
return ""
lines: List[str] = []
- def format_stacktrace(error: pb2.FetchErrorDetailsResponse.Error) -> None:
+ def format_stacktrace(error: "pb2.FetchErrorDetailsResponse.Error") -> None:
message = f"{error.error_type_hierarchy[0]}: {error.message}"
if len(lines) == 0:
lines.append(error.error_type_hierarchy[0])
@@ -170,6 +192,7 @@ def __init__(
server_stacktrace: Optional[str] = None,
display_server_stacktrace: bool = False,
contexts: Optional[List[BaseQueryContext]] = None,
+ grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
) -> None:
if contexts is None:
contexts = []
@@ -197,6 +220,7 @@ def __init__(
self._stacktrace: Optional[str] = server_stacktrace
self._display_stacktrace: bool = display_server_stacktrace
self._contexts: List[BaseQueryContext] = contexts
+ self._grpc_status_code = grpc_status_code
self._log_exception()
def getSqlState(self) -> Optional[str]:
@@ -214,6 +238,9 @@ def getMessage(self) -> str:
desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
return desc
+ def getGrpcStatusCode(self) -> grpc.StatusCode:
+ return self._grpc_status_code
+
def __str__(self) -> str:
return self.getMessage()
@@ -235,6 +262,7 @@ def __init__(
server_stacktrace: Optional[str] = None,
display_server_stacktrace: bool = False,
contexts: Optional[List[BaseQueryContext]] = None,
+ grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
) -> None:
super().__init__(
message=message,
@@ -245,6 +273,7 @@ def __init__(
server_stacktrace=server_stacktrace,
display_server_stacktrace=display_server_stacktrace,
contexts=contexts,
+ grpc_status_code=grpc_status_code,
)
@@ -393,7 +422,7 @@ class PickleException(SparkConnectGrpcException, BasePickleException):
class SQLQueryContext(BaseQueryContext):
- def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
+ def __init__(self, q: "pb2.FetchErrorDetailsResponse.QueryContext"):
self._q = q
def contextType(self) -> QueryContextType:
@@ -430,7 +459,7 @@ def summary(self) -> str:
class DataFrameQueryContext(BaseQueryContext):
- def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
+ def __init__(self, q: "pb2.FetchErrorDetailsResponse.QueryContext"):
self._q = q
def contextType(self) -> QueryContextType:
diff --git a/python/pyspark/errors/exceptions/tblib.py b/python/pyspark/errors/exceptions/tblib.py
new file mode 100644
index 0000000000000..b444f0fb45d32
--- /dev/null
+++ b/python/pyspark/errors/exceptions/tblib.py
@@ -0,0 +1,342 @@
+"""
+Class for parsing Python tracebacks.
+
+This module was adapted from the `tblib` package https://github.com/ionelmc/python-tblib
+modified to also recover line content from the traceback.
+
+BSD 2-Clause License
+
+Copyright (c) 2013-2023, Ionel Cristian Mărieș. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are
+permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this list of
+ conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice, this list
+ of conditions and the following disclaimer in the documentation and/or other materials
+ provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
+EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
+THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
+OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
+TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import re
+import sys
+from types import CodeType, FrameType, TracebackType
+from typing import Any, Dict, List, Optional
+
+__version__ = "3.0.0"
+__all__ = "Traceback", "TracebackParseError", "Frame", "Code"
+
+FRAME_RE = re.compile(
+ r'^\s*File "(?P<co_filename>.+)", line (?P<tb_lineno>\d+)(, in (?P<co_name>.+))?$'
+)
+
+
+class _AttrDict(dict):
+ __slots__ = ()
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name) from None
+
+
+# noinspection PyPep8Naming
+class __traceback_maker(Exception):
+ pass
+
+
+class TracebackParseError(Exception):
+ pass
+
+
+class Code:
+ """
+ Class that replicates just enough of the builtin Code object to enable serialization
+ and traceback rendering.
+ """
+
+ co_code: Optional[bytes] = None
+
+ def __init__(self, code: CodeType) -> None:
+ self.co_filename = code.co_filename
+ self.co_name: Optional[str] = code.co_name
+ self.co_argcount = 0
+ self.co_kwonlyargcount = 0
+ self.co_varnames = ()
+ self.co_nlocals = 0
+ self.co_stacksize = 0
+ self.co_flags = 64
+ self.co_firstlineno = 0
+
+
+class Frame:
+ """
+ Class that replicates just enough of the builtin Frame object to enable serialization
+ and traceback rendering.
+
+ Args:
+
+ get_locals (callable): A function that take a frame argument and returns a dict.
+
+ See :class:`Traceback` class for example.
+ """
+
+ def __init__(self, frame: FrameType, *, get_locals: Any = None) -> None:
+ self.f_locals = {} if get_locals is None else get_locals(frame)
+ self.f_globals = {k: v for k, v in frame.f_globals.items() if k in ("__file__", "__name__")}
+ self.f_code = Code(frame.f_code)
+ self.f_lineno = frame.f_lineno
+
+ def clear(self) -> None:
+ """
+ For compatibility with PyPy 3.5;
+ clear() was added to frame in Python 3.4
+ and is called by traceback.clear_frames(), which
+ in turn is called by unittest.TestCase.assertRaises
+ """
+
+
+class LineCacheEntry(list):
+ """
+ The list of lines in a file where only some of the lines are available.
+ """
+
+ def set_line(self, lineno: int, line: str) -> None:
+ self.extend([""] * (lineno - len(self)))
+ self[lineno - 1] = line
+
+
+class Traceback:
+ """
+ Class that wraps builtin Traceback objects.
+
+ Args:
+ get_locals (callable): A function that take a frame argument and returns a dict.
+
+ Ideally you will only return exactly what you need, and only with simple types
+ that can be json serializable.
+
+ Example:
+
+ .. code:: python
+
+ def get_locals(frame):
+ if frame.f_locals.get("__tracebackhide__"):
+ return {"__tracebackhide__": True}
+ else:
+ return {}
+ """
+
+ tb_next: Optional["Traceback"] = None
+
+ def __init__(self, tb: TracebackType, *, get_locals: Any = None):
+ self.tb_frame = Frame(tb.tb_frame, get_locals=get_locals)
+ self.tb_lineno = int(tb.tb_lineno)
+ self.cached_lines: Dict[str, Dict[int, str]] = {} # filename -> lineno -> line
+ """
+ Lines shown in the parsed traceback.
+ """
+
+ # Build in place to avoid exceeding the recursion limit
+ _tb = tb.tb_next
+ prev_traceback = self
+ cls = type(self)
+ while _tb is not None:
+ traceback = object.__new__(cls)
+ traceback.tb_frame = Frame(_tb.tb_frame, get_locals=get_locals)
+ traceback.tb_lineno = int(_tb.tb_lineno)
+ prev_traceback.tb_next = traceback
+ prev_traceback = traceback
+ _tb = _tb.tb_next
+
+ def populate_linecache(self) -> None:
+ """
+ For each cached line, update the linecache if the file is not present.
+ This helps us show the original lines even if the source file is not available,
+ for example when the parsed traceback comes from a different host.
+ """
+ import linecache
+
+ for filename, lines in self.cached_lines.items():
+ entry: list[str] = linecache.getlines(filename, module_globals=None)
+ if entry:
+ if not isinstance(entry, LineCacheEntry):
+ # no need to update the cache if the file is present
+ continue
+ else:
+ entry = LineCacheEntry()
+ linecache.cache[filename] = (1, None, entry, filename)
+ for lineno, line in lines.items():
+ entry.set_line(lineno, line)
+
+ def as_traceback(self) -> Optional[TracebackType]:
+ """
+ Convert to a builtin Traceback object that is usable for raising or rendering a stacktrace.
+ """
+ current: Optional[Traceback] = self
+ top_tb = None
+ tb = None
+ stub = compile(
+ "raise __traceback_maker",
+ "",
+ "exec",
+ )
+ while current:
+ f_code = current.tb_frame.f_code
+ code = stub.replace(
+ co_firstlineno=current.tb_lineno,
+ co_argcount=0,
+ co_filename=f_code.co_filename,
+ co_name=f_code.co_name or stub.co_name,
+ co_freevars=(),
+ co_cellvars=(),
+ )
+
+ # noinspection PyBroadException
+ try:
+ exec(
+ code, dict(current.tb_frame.f_globals), dict(current.tb_frame.f_locals)
+ ) # noqa: S102
+ except Exception:
+ next_tb = sys.exc_info()[2].tb_next # type: ignore
+ if top_tb is None:
+ top_tb = next_tb
+ if tb is not None:
+ tb.tb_next = next_tb
+ tb = next_tb
+ del next_tb
+
+ current = current.tb_next
+ try:
+ return top_tb
+ finally:
+ del top_tb
+ del tb
+
+ to_traceback = as_traceback
+
+ def as_dict(self) -> dict:
+ """
+ Converts to a dictionary representation. You can serialize the result to JSON
+ as it only has builtin objects like dicts, lists, ints or strings.
+ """
+ if self.tb_next is None:
+ tb_next = None
+ else:
+ tb_next = self.tb_next.as_dict()
+
+ code = {
+ "co_filename": self.tb_frame.f_code.co_filename,
+ "co_name": self.tb_frame.f_code.co_name,
+ }
+ frame = {
+ "f_globals": self.tb_frame.f_globals,
+ "f_locals": self.tb_frame.f_locals,
+ "f_code": code,
+ "f_lineno": self.tb_frame.f_lineno,
+ }
+ return {
+ "tb_frame": frame,
+ "tb_lineno": self.tb_lineno,
+ "tb_next": tb_next,
+ }
+
+ to_dict = as_dict
+
+ @classmethod
+ def from_dict(cls, dct: dict) -> "Traceback":
+ """
+ Creates an instance from a dictionary with the same structure as ``.as_dict()`` returns.
+ """
+ if dct["tb_next"]:
+ tb_next = cls.from_dict(dct["tb_next"])
+ else:
+ tb_next = None
+
+ code = _AttrDict(
+ co_filename=dct["tb_frame"]["f_code"]["co_filename"],
+ co_name=dct["tb_frame"]["f_code"]["co_name"],
+ )
+ frame = _AttrDict(
+ f_globals=dct["tb_frame"]["f_globals"],
+ f_locals=dct["tb_frame"].get("f_locals", {}),
+ f_code=code,
+ f_lineno=dct["tb_frame"]["f_lineno"],
+ )
+ tb = _AttrDict(
+ tb_frame=frame,
+ tb_lineno=dct["tb_lineno"],
+ tb_next=tb_next,
+ )
+ return cls(tb, get_locals=get_all_locals) # type: ignore
+
+ @classmethod
+ def from_string(cls, string: str, strict: bool = True) -> "Traceback":
+ """
+ Creates an instance by parsing a stacktrace.
+ Strict means that parsing stops when lines are not indented by at least two spaces anymore.
+ """
+
+ frames: List[Dict[str, str]] = []
+ cached_lines: Dict[str, Dict[int, str]] = {}
+
+ lines = string.splitlines()[::-1]
+ if strict: # skip the header
+ while lines:
+ line = lines.pop()
+ if line == "Traceback (most recent call last):":
+ break
+
+ while lines:
+ line = lines.pop()
+ frame_match = FRAME_RE.match(line)
+ if frame_match:
+ frames.append(frame_match.groupdict())
+ if lines and lines[-1].startswith(" "): # code for the frame
+ code = lines.pop().strip()
+ filename = frame_match.group("co_filename")
+ lineno = int(frame_match.group("tb_lineno"))
+ cached_lines.setdefault(filename, {}).setdefault(lineno, code)
+ elif line.startswith(" "):
+ pass
+ elif strict:
+ break # traceback ended
+
+ if frames:
+ previous = None
+ for frame in reversed(frames):
+ previous = _AttrDict(
+ frame,
+ tb_frame=_AttrDict(
+ frame,
+ f_globals=_AttrDict(
+ __file__=frame["co_filename"],
+ __name__="?",
+ ),
+ f_locals={},
+ f_code=_AttrDict(frame),
+ f_lineno=int(frame["tb_lineno"]),
+ ),
+ tb_next=previous,
+ )
+ self = cls(previous) # type: ignore
+ self.cached_lines = cached_lines
+ return self
+ else:
+ raise TracebackParseError("Could not find any frames in %r." % string)
+
+
+def get_all_locals(frame: FrameType) -> dict:
+ return dict(frame.f_locals)
diff --git a/python/pyspark/errors/tests/connect/__init__.py b/python/pyspark/errors/tests/connect/__init__.py
new file mode 100644
index 0000000000000..cce3acad34a49
--- /dev/null
+++ b/python/pyspark/errors/tests/connect/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/errors/tests/connect/test_parity_traceback.py b/python/pyspark/errors/tests/connect/test_parity_traceback.py
new file mode 100644
index 0000000000000..c8ef9fcefae7a
--- /dev/null
+++ b/python/pyspark/errors/tests/connect/test_parity_traceback.py
@@ -0,0 +1,35 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.errors.tests.test_traceback import BaseTracebackSqlTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class TracebackSqlConnectTests(BaseTracebackSqlTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.errors.tests.connect.test_parity_traceback import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/errors/tests/test_connect_errors_conversion.py b/python/pyspark/errors/tests/test_connect_errors_conversion.py
index a6ed5e7d391ee..344af2ad62331 100644
--- a/python/pyspark/errors/tests/test_connect_errors_conversion.py
+++ b/python/pyspark/errors/tests/test_connect_errors_conversion.py
@@ -17,20 +17,26 @@
#
import unittest
-from pyspark.errors.exceptions.connect import (
- convert_exception,
- EXCEPTION_CLASS_MAPPING,
- SparkConnectGrpcException,
- PythonException,
- AnalysisException,
-)
-from pyspark.sql.connect.proto import FetchErrorDetailsResponse as pb2
-from google.rpc.error_details_pb2 import ErrorInfo
+from pyspark.testing import should_test_connect, connect_requirement_message
+if should_test_connect:
+ from pyspark.errors.exceptions.connect import (
+ convert_exception,
+ EXCEPTION_CLASS_MAPPING,
+ SparkConnectGrpcException,
+ PythonException,
+ AnalysisException,
+ )
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ConnectErrorsTest(unittest.TestCase):
def test_convert_exception_known_class(self):
# Mock ErrorInfo with a known error class
+ from google.rpc.error_details_pb2 import ErrorInfo
+ from grpc import StatusCode
+
info = {
"reason": "org.apache.spark.sql.AnalysisException",
"metadata": {
@@ -42,16 +48,23 @@ def test_convert_exception_known_class(self):
}
truncated_message = "Analysis error occurred"
exception = convert_exception(
- info=ErrorInfo(**info), truncated_message=truncated_message, resp=None
+ info=ErrorInfo(**info),
+ truncated_message=truncated_message,
+ resp=None,
+ grpc_status_code=StatusCode.INTERNAL,
)
self.assertIsInstance(exception, AnalysisException)
self.assertEqual(exception.getSqlState(), "42000")
self.assertEqual(exception._errorClass, "ANALYSIS.ERROR")
self.assertEqual(exception._messageParameters, {"param1": "value1"})
+ self.assertEqual(exception.getGrpcStatusCode(), StatusCode.INTERNAL)
def test_convert_exception_python_exception(self):
# Mock ErrorInfo for PythonException
+ from google.rpc.error_details_pb2 import ErrorInfo
+ from grpc import StatusCode
+
info = {
"reason": "org.apache.spark.api.python.PythonException",
"metadata": {
@@ -60,27 +73,38 @@ def test_convert_exception_python_exception(self):
}
truncated_message = "Python worker error occurred"
exception = convert_exception(
- info=ErrorInfo(**info), truncated_message=truncated_message, resp=None
+ info=ErrorInfo(**info),
+ truncated_message=truncated_message,
+ resp=None,
+ grpc_status_code=StatusCode.INTERNAL,
)
self.assertIsInstance(exception, PythonException)
self.assertIn("An exception was thrown from the Python worker", exception.getMessage())
+ self.assertEqual(exception.getGrpcStatusCode(), StatusCode.INTERNAL)
def test_convert_exception_unknown_class(self):
# Mock ErrorInfo with an unknown error class
+ from google.rpc.error_details_pb2 import ErrorInfo
+ from grpc import StatusCode
+
info = {
"reason": "org.apache.spark.UnknownException",
"metadata": {"classes": '["org.apache.spark.UnknownException"]'},
}
truncated_message = "Unknown error occurred"
exception = convert_exception(
- info=ErrorInfo(**info), truncated_message=truncated_message, resp=None
+ info=ErrorInfo(**info),
+ truncated_message=truncated_message,
+ resp=None,
+ grpc_status_code=StatusCode.INTERNAL,
)
self.assertIsInstance(exception, SparkConnectGrpcException)
self.assertEqual(
exception.getMessage(), "(org.apache.spark.UnknownException) Unknown error occurred"
)
+ self.assertEqual(exception.getGrpcStatusCode(), StatusCode.INTERNAL)
def test_exception_class_mapping(self):
# Ensure that all keys in EXCEPTION_CLASS_MAPPING are valid
@@ -92,6 +116,9 @@ def test_exception_class_mapping(self):
def test_convert_exception_with_stacktrace(self):
# Mock FetchErrorDetailsResponse with stacktrace
+ from google.rpc.error_details_pb2 import ErrorInfo
+ from pyspark.sql.connect.proto import FetchErrorDetailsResponse as pb2
+
resp = pb2(
root_error_idx=0,
errors=[
@@ -132,7 +159,10 @@ def test_convert_exception_with_stacktrace(self):
}
truncated_message = "Root error message"
exception = convert_exception(
- info=ErrorInfo(**info), truncated_message=truncated_message, resp=resp
+ info=ErrorInfo(**info),
+ truncated_message=truncated_message,
+ resp=resp,
+ display_server_stacktrace=True,
)
self.assertIsInstance(exception, SparkConnectGrpcException)
@@ -141,6 +171,9 @@ def test_convert_exception_with_stacktrace(self):
def test_convert_exception_fallback(self):
# Mock ErrorInfo with missing class information
+ from google.rpc.error_details_pb2 import ErrorInfo
+ from grpc import StatusCode
+
info = {
"reason": "org.apache.spark.UnknownReason",
"metadata": {},
@@ -154,6 +187,7 @@ def test_convert_exception_fallback(self):
self.assertEqual(
exception.getMessage(), "(org.apache.spark.UnknownReason) Fallback error occurred"
)
+ self.assertEqual(exception.getGrpcStatusCode(), StatusCode.UNKNOWN)
if __name__ == "__main__":
diff --git a/python/pyspark/errors/tests/test_traceback.py b/python/pyspark/errors/tests/test_traceback.py
new file mode 100644
index 0000000000000..92955ad7f98df
--- /dev/null
+++ b/python/pyspark/errors/tests/test_traceback.py
@@ -0,0 +1,283 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import importlib.util
+import linecache
+import os
+import re
+import sys
+import tempfile
+import traceback
+import unittest
+
+import pyspark.sql.functions as sf
+from pyspark.errors import PythonException
+from pyspark.errors.exceptions.base import AnalysisException
+from pyspark.errors.exceptions.tblib import Traceback
+from pyspark.sql.datasource import DataSource, DataSourceReader
+from pyspark.sql.session import SparkSession
+from pyspark.testing.sqlutils import (
+ ReusedSQLTestCase,
+ have_pandas,
+ have_pyarrow,
+ pandas_requirement_message,
+ pyarrow_requirement_message,
+)
+
+
+class TracebackTests(unittest.TestCase):
+ """Tests for the Traceback class."""
+
+ def make_traceback(self):
+ try:
+ raise ValueError("bar")
+ except ValueError:
+ _, _, tb = sys.exc_info()
+ return traceback.format_exc(), "".join(traceback.format_tb(tb))
+
+ def make_traceback_with_temp_file(self, code="def foo(): 1 / 0", filename=""):
+ with tempfile.NamedTemporaryFile("w", suffix=f"{filename}.py", delete=True) as f:
+ f.write(code)
+ f.flush()
+ spec = importlib.util.spec_from_file_location("foo", f.name)
+ foo = importlib.util.module_from_spec(spec)
+ try:
+ spec.loader.exec_module(foo)
+ foo.foo()
+ except Exception:
+ _, _, tb = sys.exc_info()
+ return traceback.format_exc(), "".join(traceback.format_tb(tb))
+ else:
+ self.fail("Error not raised")
+
+ def assert_traceback(self, tb: Traceback, expected: str):
+ def remove_positions(s):
+ # For example, remove the 2nd line in this traceback:
+ """
+ result = (x / y / z) * (a / b / c)
+ ~~~~~~^~~
+ """
+ pattern = r"\s*[\^\~]+\s*\n"
+ return re.sub(pattern, "\n", s)
+
+ tb.populate_linecache()
+ actual = remove_positions("".join(traceback.format_tb(tb.as_traceback())))
+ expected = remove_positions(expected)
+ self.assertEqual(actual, expected)
+
+ def test_simple(self):
+ s, expected = self.make_traceback()
+ self.assert_traceback(Traceback.from_string(s), expected)
+
+ def test_missing_source(self):
+ s, expected = self.make_traceback_with_temp_file()
+ linecache.clearcache() # remove temp file from cache
+ self.assert_traceback(Traceback.from_string(s), expected)
+
+ def test_recursion(self):
+ """
+ Don't parse [Previous line repeated n times] because it's expensive for large n.
+ Since the input string is not necessarily a Python traceback, Traceback should keep its runtime
+ linear in the length of the input string to be safe against malicious inputs.
+ """
+
+ def foo(depth):
+ if depth > 0:
+ return foo(depth - 1)
+ raise 1 / 0
+
+ try:
+ foo(100)
+ except ZeroDivisionError:
+ s = traceback.format_exc()
+ actual = "".join(traceback.format_tb(Traceback.from_string(s).as_traceback()))
+ self.assertIn("[Previous line repeated", s)
+ self.assertNotIn("[Previous line repeated", actual)
+
+ @unittest.skipIf(
+ os.name != "posix",
+ "These file names may be invalid on non-posix systems",
+ )
+ def test_filename(self):
+ for filename in [
+ "",
+ " ",
+ "\\",
+ '"',
+ "'",
+ '", line 1, in hello',
+ ]:
+ with self.subTest(filename=filename):
+ s, expected = self.make_traceback_with_temp_file(filename=filename)
+ linecache.clearcache()
+ self.assert_traceback(Traceback.from_string(s), expected)
+
+ @unittest.skipIf(
+ os.name != "posix",
+        "These file names may be invalid on non-POSIX systems",
+ )
+ def test_filename_failure_newline(self):
+        # tblib can't handle a newline in the filename
+ s, expected = self.make_traceback_with_temp_file(filename="\n")
+ linecache.clearcache()
+ tb = Traceback.from_string(s)
+ tb.populate_linecache()
+ actual = "".join(traceback.format_tb(tb.as_traceback()))
+ self.assertNotEqual(actual, expected)
+
+ def test_syntax_error(self):
+ bad_syntax = "bad syntax"
+ s, _ = self.make_traceback_with_temp_file(bad_syntax)
+ tb = Traceback.from_string(s)
+ tb.populate_linecache()
+ actual = "".join(traceback.format_tb(tb.as_traceback()))
+ self.assertIn("bad syntax", actual)
+
+
+@unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message,
+)
+class BaseTracebackSqlTestsMixin:
+ """Tests for recovering the original traceback from JVM exceptions."""
+
+ spark: SparkSession
+
+ @staticmethod
+ def raise_exception():
+ raise ValueError("bar")
+
+ def assertInOnce(self, needle: str, haystack: str):
+        """Assert that a string appears exactly once in another string."""
+        count = haystack.count(needle)
+        self.assertEqual(count, 1, f"{needle} found {count} times (expected once) in {haystack}")
+
+ def test_udf(self):
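+        # The Python traceback should be recovered whether or not JVM stack traces are enabled.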
+ for jvm_stack_trace in [False, True]:
+ with self.subTest(jvm_stack_trace=jvm_stack_trace), self.sql_conf(
+ {"spark.sql.pyspark.jvmStacktrace.enabled": jvm_stack_trace}
+ ):
+
+ @sf.udf()
+ def foo():
+ raise ValueError("bar")
+
+ df = self.spark.range(1).select(foo())
+ try:
+ df.show()
+ except PythonException:
+ _, _, tb = sys.exc_info()
+ else:
+ self.fail("PythonException not raised")
+
+ s = "".join(traceback.format_tb(tb))
+ self.assertInOnce("""df.show()""", s)
+ self.assertInOnce("""raise ValueError("bar")""", s)
+
+ def test_datasource_analysis(self):
+ class MyDataSource(DataSource):
+ def schema(self):
+ raise ValueError("bar")
+
+ self.spark.dataSource.register(MyDataSource)
+ try:
+ self.spark.read.format("MyDataSource").load().show()
+ except AnalysisException:
+ _, _, tb = sys.exc_info()
+ else:
+ self.fail("AnalysisException not raised")
+
+ s = "".join(traceback.format_tb(tb))
+ self.assertInOnce("""self.spark.read.format("MyDataSource").load().show()""", s)
+ self.assertInOnce("""raise ValueError("bar")""", s)
+
+ def test_datasource_execution(self):
+ class MyDataSource(DataSource):
+ def schema(self):
+ return "x int"
+
+ def reader(self, schema):
+ return MyDataSourceReader()
+
+ class MyDataSourceReader(DataSourceReader):
+ def read(self, partitions):
+ raise ValueError("bar")
+
+ self.spark.dataSource.register(MyDataSource)
+ try:
+ self.spark.read.format("MyDataSource").load().show()
+ except PythonException:
+ _, _, tb = sys.exc_info()
+ else:
+ self.fail("PythonException not raised")
+
+ s = "".join(traceback.format_tb(tb))
+ self.assertInOnce("""self.spark.read.format("MyDataSource").load().show()""", s)
+ self.assertInOnce("""raise ValueError("bar")""", s)
+
+ def test_udtf_analysis(self):
+ @sf.udtf()
+ class MyUdtf:
+ @staticmethod
+ def analyze():
+ raise ValueError("bar")
+
+ def eval(self):
+ pass
+
+ try:
+ MyUdtf().show()
+ except AnalysisException:
+ _, _, tb = sys.exc_info()
+ else:
+ self.fail("AnalysisException not raised")
+
+ s = "".join(traceback.format_tb(tb))
+ self.assertInOnce("""MyUdtf().show()""", s)
+ self.assertInOnce("""raise ValueError("bar")""", s)
+
+ def test_udtf_execution(self):
+ @sf.udtf(returnType="x int")
+ class MyUdtf:
+ def eval(self):
+ raise ValueError("bar")
+
+ try:
+ MyUdtf().show()
+ except PythonException:
+ _, _, tb = sys.exc_info()
+ else:
+ self.fail("PythonException not raised")
+
+ s = "".join(traceback.format_tb(tb))
+ self.assertInOnce("""MyUdtf().show()""", s)
+ self.assertInOnce("""raise ValueError("bar")""", s)
+
+
+class TracebackSqlClassicTests(BaseTracebackSqlTestsMixin, ReusedSQLTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.errors.tests.test_traceback import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 0bdfa27fc7021..224ef34fd5edc 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -328,8 +328,7 @@ def transformSchema(self, schema: StructType) -> StructType:
def _transform(self, dataset: DataFrame) -> DataFrame:
self.transformSchema(dataset.schema)
- # TODO(SPARK-48515): Use Arrow Python UDF
- transformUDF = udf(self.createTransformFunc(), self.outputDataType(), useArrow=False)
+ transformUDF = udf(self.createTransformFunc(), self.outputDataType())
transformedDataset = dataset.withColumn(
self.getOutputCol(), transformUDF(dataset[self.getInputCol()])
)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 21e4ee4f6d0e6..a5fdaed0db2c4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -89,6 +89,7 @@
try_remote_read,
try_remote_write,
try_remote_attribute_relation,
+ _cache_spark_dataset,
)
from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc
@@ -889,7 +890,10 @@ def summary(self) -> "LinearSVCTrainingSummary": # type: ignore[override]
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
- return LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
+ s = LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
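+            # On Connect, keep the source model reachable from the summary so its
+            # server-side cache entry is not released while the summary is in use.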
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -909,7 +913,10 @@ def evaluate(self, dataset: DataFrame) -> "LinearSVCSummary":
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_lsvc_summary = self._call_java("evaluate", dataset)
- return LinearSVCSummary(java_lsvc_summary)
+ s = LinearSVCSummary(java_lsvc_summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
class LinearSVCSummary(_BinaryClassificationSummary):
@@ -1578,14 +1585,16 @@ def summary(self) -> "LogisticRegressionTrainingSummary":
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
+ s: LogisticRegressionTrainingSummary
if self.numClasses <= 2:
- return BinaryLogisticRegressionTrainingSummary(
+ s = BinaryLogisticRegressionTrainingSummary(
super(LogisticRegressionModel, self).summary
)
else:
- return LogisticRegressionTrainingSummary(
- super(LogisticRegressionModel, self).summary
- )
+ s = LogisticRegressionTrainingSummary(super(LogisticRegressionModel, self).summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -1605,10 +1614,14 @@ def evaluate(self, dataset: DataFrame) -> "LogisticRegressionSummary":
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_blr_summary = self._call_java("evaluate", dataset)
+ s: LogisticRegressionSummary
if self.numClasses <= 2:
- return BinaryLogisticRegressionSummary(java_blr_summary)
+ s = BinaryLogisticRegressionSummary(java_blr_summary)
else:
- return LogisticRegressionSummary(java_blr_summary)
+ s = LogisticRegressionSummary(java_blr_summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
class LogisticRegressionSummary(_ClassificationSummary):
@@ -2293,7 +2306,12 @@ def featureImportances(self) -> Vector:
def trees(self) -> List[DecisionTreeClassificationModel]:
"""Trees in this ensemble. Warning: These have null parent Estimators."""
if is_remote():
- return [DecisionTreeClassificationModel(m) for m in self._call_java("trees").split(",")]
+ from pyspark.ml.util import RemoteModelRef
+
+ return [
+ DecisionTreeClassificationModel(RemoteModelRef(m))
+ for m in self._call_java("trees").split(",")
+ ]
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
@property
@@ -2304,22 +2322,24 @@ def summary(self) -> "RandomForestClassificationTrainingSummary":
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
+ s: RandomForestClassificationTrainingSummary
if self.numClasses <= 2:
- return BinaryRandomForestClassificationTrainingSummary(
+ s = BinaryRandomForestClassificationTrainingSummary(
super(RandomForestClassificationModel, self).summary
)
else:
- return RandomForestClassificationTrainingSummary(
+ s = RandomForestClassificationTrainingSummary(
super(RandomForestClassificationModel, self).summary
)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
)
- def evaluate(
- self, dataset: DataFrame
- ) -> Union["BinaryRandomForestClassificationSummary", "RandomForestClassificationSummary"]:
+ def evaluate(self, dataset: DataFrame) -> "RandomForestClassificationSummary":
"""
Evaluates the model on a test dataset.
@@ -2333,10 +2353,14 @@ def evaluate(
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_rf_summary = self._call_java("evaluate", dataset)
+ s: RandomForestClassificationSummary
if self.numClasses <= 2:
- return BinaryRandomForestClassificationSummary(java_rf_summary)
+ s = BinaryRandomForestClassificationSummary(java_rf_summary)
else:
- return RandomForestClassificationSummary(java_rf_summary)
+ s = RandomForestClassificationSummary(java_rf_summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
class RandomForestClassificationSummary(_ClassificationSummary):
@@ -2363,7 +2387,10 @@ class RandomForestClassificationTrainingSummary(
@inherit_doc
-class BinaryRandomForestClassificationSummary(_BinaryClassificationSummary):
+class BinaryRandomForestClassificationSummary(
+ _BinaryClassificationSummary,
+ RandomForestClassificationSummary,
+):
"""
BinaryRandomForestClassification results for a given model.
@@ -2783,7 +2810,12 @@ def featureImportances(self) -> Vector:
def trees(self) -> List[DecisionTreeRegressionModel]:
"""Trees in this ensemble. Warning: These have null parent Estimators."""
if is_remote():
- return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+ from pyspark.ml.util import RemoteModelRef
+
+ return [
+ DecisionTreeRegressionModel(RemoteModelRef(m))
+ for m in self._call_java("trees").split(",")
+ ]
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
def evaluateEachIteration(self, dataset: DataFrame) -> List[float]:
@@ -3341,9 +3373,12 @@ def summary( # type: ignore[override]
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
- return MultilayerPerceptronClassificationTrainingSummary(
+ s = MultilayerPerceptronClassificationTrainingSummary(
super(MultilayerPerceptronClassificationModel, self).summary
)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -3363,7 +3398,10 @@ def evaluate(self, dataset: DataFrame) -> "MultilayerPerceptronClassificationSum
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_mlp_summary = self._call_java("evaluate", dataset)
- return MultilayerPerceptronClassificationSummary(java_mlp_summary)
+ s = MultilayerPerceptronClassificationSummary(java_mlp_summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
class MultilayerPerceptronClassificationSummary(_ClassificationSummary):
@@ -3576,46 +3614,47 @@ def _fit(self, dataset: DataFrame) -> "OneVsRestModel":
# persist if underlying dataset is not persistent.
handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
- if handlePersistence:
- multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
-
- def _oneClassFitTasks(numClasses: int) -> List[Callable[[], Tuple[int, CM]]]:
- indices = iter(range(numClasses))
- def trainSingleClass() -> Tuple[int, CM]:
- index = next(indices)
-
- binaryLabelCol = "mc2b$" + str(index)
- trainingDataset = multiclassLabeled.withColumn(
- binaryLabelCol,
- F.when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
- )
- paramMap = dict(
- [
- (classifier.labelCol, binaryLabelCol),
- (classifier.featuresCol, featuresCol),
- (classifier.predictionCol, predictionCol),
- ]
- )
- if weightCol:
- paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
- return index, classifier.fit(trainingDataset, paramMap)
-
- return [trainSingleClass] * numClasses
-
- tasks = map(
- inheritable_thread_target(dataset.sparkSession),
- _oneClassFitTasks(numClasses),
- )
- pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
-
- subModels = [None] * numClasses
- for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
- assert subModels is not None
- subModels[j] = subModel
+ with _cache_spark_dataset(
+ multiclassLabeled,
+ storageLevel=StorageLevel.MEMORY_AND_DISK,
+ enable=handlePersistence,
+ ) as multiclassLabeled:
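+            # _cache_spark_dataset persists the dataset when enable=True and unpersists it on exit.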
+
+ def _oneClassFitTasks(numClasses: int) -> List[Callable[[], Tuple[int, CM]]]:
+ indices = iter(range(numClasses))
+
+ def trainSingleClass() -> Tuple[int, CM]:
+ index = next(indices)
+
+ binaryLabelCol = "mc2b$" + str(index)
+ trainingDataset = multiclassLabeled.withColumn(
+ binaryLabelCol,
+ F.when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+ )
+ paramMap = dict(
+ [
+ (classifier.labelCol, binaryLabelCol),
+ (classifier.featuresCol, featuresCol),
+ (classifier.predictionCol, predictionCol),
+ ]
+ )
+ if weightCol:
+ paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+ return index, classifier.fit(trainingDataset, paramMap)
+
+ return [trainSingleClass] * numClasses
+
+ tasks = map(
+ inheritable_thread_target(dataset.sparkSession),
+ _oneClassFitTasks(numClasses),
+ )
+ pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
- if handlePersistence:
- multiclassLabeled.unpersist()
+ subModels = [None] * numClasses
+ for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
+ assert subModels is not None
+ subModels[j] = subModel
return self._copyValues(OneVsRestModel(models=cast(List[ClassificationModel], subModels)))
@@ -3841,32 +3880,31 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
# persist if underlying dataset is not persistent.
handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
- if handlePersistence:
- newDataset.persist(StorageLevel.MEMORY_AND_DISK)
-
- # update the accumulator column with the result of prediction of models
- aggregatedDataset = newDataset
- for index, model in enumerate(self.models):
- rawPredictionCol = self.getRawPredictionCol()
-
- columns = origCols + [rawPredictionCol, accColName]
-
- # add temporary column to store intermediate scores and update
- tmpColName = "mbc$tmp" + str(uuid.uuid4())
- transformedDataset = model.transform(aggregatedDataset).select(*columns)
- updatedDataset = transformedDataset.withColumn(
- tmpColName,
- F.array_append(accColName, SF.vector_get(F.col(rawPredictionCol), F.lit(1))),
- )
- newColumns = origCols + [tmpColName]
-
- # switch out the intermediate column with the accumulator column
- aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
- tmpColName, accColName
- )
+ with _cache_spark_dataset(
+ newDataset,
+ storageLevel=StorageLevel.MEMORY_AND_DISK,
+ enable=handlePersistence,
+ ) as newDataset:
+ # update the accumulator column with the result of prediction of models
+ aggregatedDataset = newDataset
+ for index, model in enumerate(self.models):
+ rawPredictionCol = self.getRawPredictionCol()
+
+ columns = origCols + [rawPredictionCol, accColName]
+
+ # add temporary column to store intermediate scores and update
+ tmpColName = "mbc$tmp" + str(uuid.uuid4())
+ transformedDataset = model.transform(aggregatedDataset).select(*columns)
+ updatedDataset = transformedDataset.withColumn(
+ tmpColName,
+ F.array_append(accColName, SF.vector_get(F.col(rawPredictionCol), F.lit(1))),
+ )
+ newColumns = origCols + [tmpColName]
- if handlePersistence:
- newDataset.unpersist()
+ # switch out the intermediate column with the accumulator column
+ aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
+ tmpColName, accColName
+ )
if self.getRawPredictionCol():
aggregatedDataset = aggregatedDataset.withColumn(
@@ -4290,7 +4328,10 @@ def summary(self) -> "FMClassificationTrainingSummary":
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
- return FMClassificationTrainingSummary(super(FMClassificationModel, self).summary)
+ s = FMClassificationTrainingSummary(super(FMClassificationModel, self).summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -4310,7 +4351,10 @@ def evaluate(self, dataset: DataFrame) -> "FMClassificationSummary":
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_fm_summary = self._call_java("evaluate", dataset)
- return FMClassificationSummary(java_fm_summary)
+ s = FMClassificationSummary(java_fm_summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
class FMClassificationSummary(_BinaryClassificationSummary):
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 1ba427cda7e08..7267ee2805987 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -263,7 +263,10 @@ def summary(self) -> "GaussianMixtureSummary":
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
- return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
+ s = GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -710,7 +713,10 @@ def summary(self) -> KMeansSummary:
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
- return KMeansSummary(super(KMeansModel, self).summary)
+ s = KMeansSummary(super(KMeansModel, self).summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -1057,7 +1063,10 @@ def summary(self) -> "BisectingKMeansSummary":
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
- return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
+ s = BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
diff --git a/python/pyspark/ml/connect/__init__.py b/python/pyspark/ml/connect/__init__.py
index 875a5370d996d..6a5453db0be9c 100644
--- a/python/pyspark/ml/connect/__init__.py
+++ b/python/pyspark/ml/connect/__init__.py
@@ -16,10 +16,6 @@
#
"""Spark Connect Python Client - ML module"""
-from pyspark.sql.connect.utils import check_dependencies
-
-check_dependencies(__name__)
-
from pyspark.ml.connect.base import (
Estimator,
Transformer,
diff --git a/python/pyspark/ml/connect/base.py b/python/pyspark/ml/connect/base.py
index 516b5057cc192..32c72d5907455 100644
--- a/python/pyspark/ml/connect/base.py
+++ b/python/pyspark/ml/connect/base.py
@@ -39,7 +39,6 @@
HasFeaturesCol,
HasPredictionCol,
)
-from pyspark.ml.connect.util import transform_dataframe_column
if TYPE_CHECKING:
from pyspark.ml._typing import ParamMap
@@ -188,6 +187,8 @@ def transform(
return self._transform(dataset)
def _transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
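+        # Imported lazily so this module does not require the Connect dependencies at import time.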
+ from pyspark.ml.connect.util import transform_dataframe_column
+
input_cols = self._input_columns()
transform_fn = self._get_transform_fn()
output_cols = self._output_columns()
diff --git a/python/pyspark/ml/connect/evaluation.py b/python/pyspark/ml/connect/evaluation.py
index 267094f12a027..f324bb193c0ce 100644
--- a/python/pyspark/ml/connect/evaluation.py
+++ b/python/pyspark/ml/connect/evaluation.py
@@ -24,7 +24,6 @@
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasProbabilityCol
from pyspark.ml.connect.base import Evaluator
from pyspark.ml.connect.io_utils import ParamsReadWrite
-from pyspark.ml.connect.util import aggregate_dataframe
from pyspark.sql import DataFrame
@@ -56,6 +55,8 @@ def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
raise NotImplementedError()
def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float:
+ from pyspark.ml.connect.util import aggregate_dataframe
+
torch_metric = self._get_torch_metric()
def local_agg_fn(pandas_df: "pd.DataFrame") -> "pd.DataFrame":
diff --git a/python/pyspark/ml/connect/feature.py b/python/pyspark/ml/connect/feature.py
index a0e5b6a943d10..b0e2028e43faa 100644
--- a/python/pyspark/ml/connect/feature.py
+++ b/python/pyspark/ml/connect/feature.py
@@ -35,7 +35,6 @@
)
from pyspark.ml.connect.base import Estimator, Model, Transformer
from pyspark.ml.connect.io_utils import ParamsReadWrite, CoreModelReadWrite
-from pyspark.ml.connect.summarizer import summarize_dataframe
class MaxAbsScaler(Estimator, HasInputCol, HasOutputCol, ParamsReadWrite):
@@ -81,6 +80,8 @@ def __init__(self, *, inputCol: Optional[str] = None, outputCol: Optional[str] =
self._set(**kwargs)
def _fit(self, dataset: Union["pd.DataFrame", "DataFrame"]) -> "MaxAbsScalerModel":
+ from pyspark.ml.connect.summarizer import summarize_dataframe
+
input_col = self.getInputCol()
stat_res = summarize_dataframe(dataset, input_col, ["min", "max", "count"])
@@ -197,6 +198,8 @@ def __init__(self, inputCol: Optional[str] = None, outputCol: Optional[str] = No
self._set(**kwargs)
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "StandardScalerModel":
+ from pyspark.ml.connect.summarizer import summarize_dataframe
+
input_col = self.getInputCol()
stat_result = summarize_dataframe(dataset, input_col, ["mean", "std", "count"])
diff --git a/python/pyspark/ml/connect/functions.py b/python/pyspark/ml/connect/functions.py
index e3664db87ae64..22ff32a5946de 100644
--- a/python/pyspark/ml/connect/functions.py
+++ b/python/pyspark/ml/connect/functions.py
@@ -19,13 +19,15 @@
from pyspark.ml import functions as PyMLFunctions
from pyspark.sql.column import Column
-from pyspark.sql.connect.functions.builtin import _invoke_function, _to_col, lit
+
if TYPE_CHECKING:
from pyspark.sql._typing import UserDefinedFunctionLike
def vector_to_array(col: Column, dtype: str = "float64") -> Column:
+ from pyspark.sql.connect.functions.builtin import _invoke_function, _to_col, lit
+
return _invoke_function("vector_to_array", _to_col(col), lit(dtype))
@@ -33,6 +35,8 @@ def vector_to_array(col: Column, dtype: str = "float64") -> Column:
def array_to_vector(col: Column) -> Column:
+ from pyspark.sql.connect.functions.builtin import _invoke_function, _to_col
+
return _invoke_function("array_to_vector", _to_col(col))
@@ -49,25 +53,21 @@ def predict_batch_udf(*args: Any, **kwargs: Any) -> "UserDefinedFunctionLike":
def _test() -> None:
import os
import sys
- import doctest
- from pyspark.sql import SparkSession as PySparkSession
- import pyspark.ml.connect.functions
- from pyspark.sql.pandas.utils import (
- require_minimum_pandas_version,
- require_minimum_pyarrow_version,
- )
+ if os.environ.get("PYTHON_GIL", "?") == "0":
+ print("Not supported in no-GIL mode", file=sys.stderr)
+ sys.exit(0)
+
+ from pyspark.testing import should_test_connect
- try:
- require_minimum_pandas_version()
- require_minimum_pyarrow_version()
- except Exception as e:
- print(
- f"Skipping pyspark.ml.functions doctests: {e}",
- file=sys.stderr,
- )
+ if not should_test_connect:
+ print("Skipping pyspark.ml.connect.functions doctests", file=sys.stderr)
sys.exit(0)
+ import doctest
+ from pyspark.sql import SparkSession as PySparkSession
+ import pyspark.ml.connect.functions
+
globs = pyspark.ml.connect.functions.__dict__.copy()
globs["spark"] = (
diff --git a/python/pyspark/ml/connect/proto.py b/python/pyspark/ml/connect/proto.py
index b0e012964fc4a..31f100859281a 100644
--- a/python/pyspark/ml/connect/proto.py
+++ b/python/pyspark/ml/connect/proto.py
@@ -14,6 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
from typing import Optional, TYPE_CHECKING, List
import pyspark.sql.connect.proto as pb2
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 95551f67c0120..ff53eb77d0326 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -14,6 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any
@@ -74,11 +77,13 @@ def saveInstance(
# Spark Connect ML is built on scala Spark.ML, that means we're only
# supporting JavaModel or JavaEstimator or JavaEvaluator
if isinstance(instance, JavaModel):
+ from pyspark.ml.util import RemoteModelRef
+
model = cast("JavaModel", instance)
params = serialize_ml_params(model, session.client)
- assert isinstance(model._java_obj, str)
+ assert isinstance(model._java_obj, RemoteModelRef)
writer = pb2.MlCommand.Write(
- obj_ref=pb2.ObjectRef(id=model._java_obj),
+ obj_ref=pb2.ObjectRef(id=model._java_obj.ref_id),
params=params,
path=path,
should_overwrite=shouldOverwrite,
@@ -267,9 +272,12 @@ def _get_class() -> Type[RL]:
py_type = _get_class()
# It must be JavaWrapper, since we're passing the string to the _java_obj
if issubclass(py_type, JavaWrapper):
+ from pyspark.ml.util import RemoteModelRef
+
if ml_type == pb2.MlOperator.OPERATOR_TYPE_MODEL:
session.client.add_ml_cache(result.obj_ref.id)
- instance = py_type(result.obj_ref.id)
+ remote_model_ref = RemoteModelRef(result.obj_ref.id)
+ instance = py_type(remote_model_ref)
else:
instance = py_type()
instance._resetUid(result.uid)
diff --git a/python/pyspark/ml/connect/serialize.py b/python/pyspark/ml/connect/serialize.py
index 42bedfb330b1b..37102d463b057 100644
--- a/python/pyspark/ml/connect/serialize.py
+++ b/python/pyspark/ml/connect/serialize.py
@@ -14,6 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
from typing import Any, List, TYPE_CHECKING, Mapping, Dict
import pyspark.sql.connect.proto as pb2
diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py
index 2bbc63ef4dc2a..1ef055d25007c 100644
--- a/python/pyspark/ml/connect/tuning.py
+++ b/python/pyspark/ml/connect/tuning.py
@@ -434,7 +434,7 @@ def _fit(self, dataset: Union[pd.DataFrame, DataFrame]) -> "CrossValidatorModel"
tasks = _parallelFitTasks(est, train, eva, validation, epm)
if not is_remote():
- tasks = list(map(inheritable_thread_target, tasks))
+ tasks = list(map(inheritable_thread_target(dataset.sparkSession), tasks))
for j, metric in pool.imap_unordered(lambda f: f(), tasks):
metrics_all[i][j] = metric
diff --git a/python/pyspark/ml/connect/util.py b/python/pyspark/ml/connect/util.py
index 1c77baeba5f88..8bf4b3480e32e 100644
--- a/python/pyspark/ml/connect/util.py
+++ b/python/pyspark/ml/connect/util.py
@@ -15,14 +15,16 @@
# limitations under the License.
#
-from typing import Any, TypeVar, Callable, List, Tuple, Union, Iterable
+from typing import Any, TypeVar, Callable, List, Tuple, Union, Iterable, TYPE_CHECKING
import pandas as pd
from pyspark import cloudpickle
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, pandas_udf
-import pyspark.sql.connect.proto as pb2
+
+if TYPE_CHECKING:
+ import pyspark.sql.connect.proto as pb2
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
@@ -180,6 +182,8 @@ def transform_fn_pandas_udf(*s: "pd.Series") -> "pd.Series":
def _extract_id_methods(obj_identifier: str) -> Tuple[List["pb2.Fetch.Method"], str]:
"""Extract the obj reference id and the methods. Eg, model.summary"""
+ import pyspark.sql.connect.proto as pb2
+
method_chain = obj_identifier.split(".")
obj_ref = method_chain[0]
methods: List["pb2.Fetch.Method"] = []
diff --git a/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py b/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
index 66a9b553cc751..e614c347faa90 100644
--- a/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
+++ b/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
@@ -227,7 +227,7 @@ def setUpClass(cls) -> None:
conf = conf.set(k, v)
conf = conf.set(
"spark.worker.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
- )
+ ).set("spark.python.unix.domain.socket.enabled", "false")
sc = SparkContext("local-cluster[2,2,512]", cls.__name__, conf=conf)
cls.spark = SparkSession(sc)
@@ -264,7 +264,7 @@ def setUpClass(cls) -> None:
conf = conf.set(k, v)
conf = conf.set(
"spark.driver.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
- )
+ ).set("spark.python.unix.domain.socket.enabled", "false")
sc = SparkContext("local-cluster[2,2,512]", cls.__name__, conf=conf)
cls.spark = SparkSession(sc)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index d669fab27d505..4d1551652028a 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -64,6 +64,7 @@
_jvm,
)
from pyspark.ml.common import inherit_doc
+from pyspark.ml.util import RemoteModelRef
from pyspark.sql.types import ArrayType, StringType
from pyspark.sql.utils import is_remote
@@ -1224,10 +1225,12 @@ def from_vocabulary(
if is_remote():
model = CountVectorizerModel()
- model._java_obj = invoke_helper_attr(
- "countVectorizerModelFromVocabulary",
- model.uid,
- list(vocabulary),
+ model._java_obj = RemoteModelRef(
+ invoke_helper_attr(
+ "countVectorizerModelFromVocabulary",
+ model.uid,
+ list(vocabulary),
+ )
)
else:
@@ -4843,10 +4846,12 @@ def from_labels(
"""
if is_remote():
model = StringIndexerModel()
- model._java_obj = invoke_helper_attr(
- "stringIndexerModelFromLabels",
- model.uid,
- (list(labels), ArrayType(StringType())),
+ model._java_obj = RemoteModelRef(
+ invoke_helper_attr(
+ "stringIndexerModelFromLabels",
+ model.uid,
+ (list(labels), ArrayType(StringType())),
+ )
)
else:
@@ -4882,13 +4887,15 @@ def from_arrays_of_labels(
"""
if is_remote():
model = StringIndexerModel()
- model._java_obj = invoke_helper_attr(
- "stringIndexerModelFromLabelsArray",
- model.uid,
- (
- [list(labels) for labels in arrayOfLabels],
- ArrayType(ArrayType(StringType())),
- ),
+ model._java_obj = RemoteModelRef(
+ invoke_helper_attr(
+ "stringIndexerModelFromLabelsArray",
+ model.uid,
+ (
+ [list(labels) for labels in arrayOfLabels],
+ ArrayType(ArrayType(StringType())),
+ ),
+ )
)
else:
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index b4f9c6000b63e..66d6dbd6a2678 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -487,7 +487,10 @@ def summary(self) -> "LinearRegressionTrainingSummary":
`trainingSummary is None`.
"""
if self.hasSummary:
- return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
+ s = LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -508,7 +511,10 @@ def evaluate(self, dataset: DataFrame) -> "LinearRegressionSummary":
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_lr_summary = self._call_java("evaluate", dataset)
- return LinearRegressionSummary(java_lr_summary)
+ s = LinearRegressionSummary(java_lr_summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
class LinearRegressionSummary(JavaWrapper):
@@ -1608,7 +1614,12 @@ class RandomForestRegressionModel(
def trees(self) -> List[DecisionTreeRegressionModel]:
"""Trees in this ensemble. Warning: These have null parent Estimators."""
if is_remote():
- return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+ from pyspark.ml.util import RemoteModelRef
+
+ return [
+ DecisionTreeRegressionModel(RemoteModelRef(m))
+ for m in self._call_java("trees").split(",")
+ ]
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
@property
@@ -1999,7 +2010,12 @@ def featureImportances(self) -> Vector:
def trees(self) -> List[DecisionTreeRegressionModel]:
"""Trees in this ensemble. Warning: These have null parent Estimators."""
if is_remote():
- return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+ from pyspark.ml.util import RemoteModelRef
+
+ return [
+ DecisionTreeRegressionModel(RemoteModelRef(m))
+ for m in self._call_java("trees").split(",")
+ ]
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
def evaluateEachIteration(self, dataset: DataFrame, loss: str) -> List[float]:
@@ -2766,9 +2782,12 @@ def summary(self) -> "GeneralizedLinearRegressionTrainingSummary":
`trainingSummary is None`.
"""
if self.hasSummary:
- return GeneralizedLinearRegressionTrainingSummary(
+ s = GeneralizedLinearRegressionTrainingSummary(
super(GeneralizedLinearRegressionModel, self).summary
)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -2789,7 +2808,10 @@ def evaluate(self, dataset: DataFrame) -> "GeneralizedLinearRegressionSummary":
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_glr_summary = self._call_java("evaluate", dataset)
- return GeneralizedLinearRegressionSummary(java_glr_summary)
+ s = GeneralizedLinearRegressionSummary(java_glr_summary)
+ if is_remote():
+ s.__source_transformer__ = self # type: ignore[attr-defined]
+ return s
class GeneralizedLinearRegressionSummary(JavaWrapper):
diff --git a/python/pyspark/ml/tests/connect/test_connect_cache.py b/python/pyspark/ml/tests/connect/test_connect_cache.py
new file mode 100644
index 0000000000000..8d31328dc7c95
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_connect_cache.py
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.classification import LinearSVC
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class MLConnectCacheTests(ReusedConnectTestCase):
+ def test_delete_model(self):
+ spark = self.spark
+ df = (
+ spark.createDataFrame(
+ [
+ (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+ (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+ (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+ (0.0, 4.0, Vectors.dense(3.0, 3.0)),
+ ],
+ ["label", "weight", "features"],
+ )
+ .coalesce(1)
+ .sortWithinPartitions("weight")
+ )
+ svc = LinearSVC(maxIter=1, regParam=1.0)
+
+ model = svc.fit(df)
+
+        # the model is cached on the Python side
+ self.assertEqual(len(spark.client.thread_local.ml_caches), 1)
+ cache_info = spark.client._get_ml_cache_info()
+ self.assertEqual(len(cache_info), 1)
+ self.assertTrue(
+ "obj: class org.apache.spark.ml.classification.LinearSVCModel" in cache_info[0],
+ cache_info,
+ )
+
+ # explicitly delete the model
+ del model
+
+        # the model is no longer cached on the Python side
+ self.assertEqual(len(spark.client.thread_local.ml_caches), 0)
+ cache_info = spark.client._get_ml_cache_info()
+ self.assertEqual(len(cache_info), 0)
+
+ def test_cleanup_ml_cache(self):
+ spark = self.spark
+ df = (
+ spark.createDataFrame(
+ [
+ (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+ (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+ (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+ (0.0, 4.0, Vectors.dense(3.0, 3.0)),
+ ],
+ ["label", "weight", "features"],
+ )
+ .coalesce(1)
+ .sortWithinPartitions("weight")
+ )
+
+ svc = LinearSVC(maxIter=1, regParam=1.0)
+ model1 = svc.fit(df)
+ model2 = svc.fit(df)
+ model3 = svc.fit(df)
+ self.assertEqual(len([model1, model2, model3]), 3)
+
+        # all 3 models are cached on the Python side
+ self.assertEqual(len(spark.client.thread_local.ml_caches), 3)
+ cache_info = spark.client._get_ml_cache_info()
+ self.assertEqual(len(cache_info), 3)
+ self.assertTrue(
+ all(
+ "obj: class org.apache.spark.ml.classification.LinearSVCModel" in c
+ for c in cache_info
+ ),
+ cache_info,
+ )
+
+        # explicitly delete model1
+ del model1
+
+        # model1 is no longer cached on the Python side
+ self.assertEqual(len(spark.client.thread_local.ml_caches), 2)
+ cache_info = spark.client._get_ml_cache_info()
+ self.assertEqual(len(cache_info), 2)
+
+ spark.client._cleanup_ml_cache()
+
+        # all models have been removed on the Python side
+ self.assertEqual(len(spark.client.thread_local.ml_caches), 0)
+ cache_info = spark.client._get_ml_cache_info()
+ self.assertEqual(len(cache_info), 0)
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.connect.test_connect_cache import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py
index d3e86a3fb9df7..e9ccd7b0369e5 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -20,10 +20,10 @@
import os
from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torch, torch_requirement_message
+from pyspark.testing.connectutils import ReusedConnectTestCase
@unittest.skipIf(
@@ -32,16 +32,16 @@
or torch_requirement_message
or "Requires PySpark core library in Spark Connect server",
)
-class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = (
- SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
- .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
- .getOrCreate()
- )
-
- def tearDown(self) -> None:
- self.spark.stop()
+class ClassificationTestsOnConnect(ClassificationTestsMixin, ReusedConnectTestCase):
+ @classmethod
+ def conf(cls):
+ config = super(ClassificationTestsOnConnect, cls).conf()
+ config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
+ return config
+
+ @classmethod
+ def master(cls):
+ return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
index 662fe8a2ffdfc..73b9e0943bea8 100644
--- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
@@ -18,21 +18,17 @@
import os
import unittest
-from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect
+from pyspark.testing.connectutils import ReusedConnectTestCase
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_evaluation import EvaluationTestsMixin
@unittest.skip("SPARK-50956: Flaky with RetriesExceeded")
- class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = SparkSession.builder.remote(
- os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
- ).getOrCreate()
-
- def tearDown(self) -> None:
- self.spark.stop()
+ class EvaluationTestsOnConnect(EvaluationTestsMixin, ReusedConnectTestCase):
+ @classmethod
+ def master(cls):
+ return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py
index 879cbff6d0cc7..04a8d3664b96c 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -18,9 +18,9 @@
import os
import unittest
-from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_sklearn, sklearn_requirement_message
+from pyspark.testing.connectutils import ReusedConnectTestCase
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin
@@ -29,14 +29,10 @@
not should_test_connect or not have_sklearn,
connect_requirement_message or sklearn_requirement_message,
)
- class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = SparkSession.builder.remote(
- os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
- ).getOrCreate()
-
- def tearDown(self) -> None:
- self.spark.stop()
+ class FeatureTestsOnConnect(FeatureTestsMixin, ReusedConnectTestCase):
+ @classmethod
+ def master(cls):
+ return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_function.py b/python/pyspark/ml/tests/connect/test_connect_function.py
index 7d3a115ab0619..3d428232ebb38 100644
--- a/python/pyspark/ml/tests/connect/test_connect_function.py
+++ b/python/pyspark/ml/tests/connect/test_connect_function.py
@@ -14,62 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import os
import unittest
from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession as PySparkSession
from pyspark.ml import functions as SF
-from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.testing.connectutils import (
should_test_connect,
- ReusedConnectTestCase,
+ ReusedMixedTestCase,
)
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
if should_test_connect:
- from pyspark.sql.connect.dataframe import DataFrame as CDF
from pyspark.ml.connect import functions as CF
@unittest.skipIf(is_remote_only(), "Requires JVM access")
-class SparkConnectMLFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils):
+class SparkConnectMLFunctionTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
"""These test cases exercise the interface to the proto plan
generation but do not call Spark."""
- @classmethod
- def setUpClass(cls):
- super(SparkConnectMLFunctionTests, cls).setUpClass()
- # Disable the shared namespace so pyspark.sql.functions, etc point the regular
- # PySpark libraries.
- os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
- cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session.
- cls.spark = PySparkSession._instantiatedSession
- assert cls.spark is not None
-
- @classmethod
- def tearDownClass(cls):
- cls.spark = cls.connect # Stopping Spark Connect closes the session in JVM at the server.
- super(SparkConnectMLFunctionTests, cls).tearDownClass()
- del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
-
- def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
- from pyspark.sql.classic.dataframe import DataFrame as SDF
-
- assert isinstance(df1, (SDF, CDF))
- if isinstance(df1, SDF):
- str1 = df1._jdf.showString(n, truncate, False)
- else:
- str1 = df1._show_string(n, truncate, False)
-
- assert isinstance(df2, (SDF, CDF))
- if isinstance(df2, SDF):
- str2 = df2._jdf.showString(n, truncate, False)
- else:
- str2 = df2._show_string(n, truncate, False)
-
- self.assertEqual(str1, str2)
-
def test_array_vector_conversion(self):
query = """
SELECT * FROM VALUES
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index f8576d0cb09da..2b408911fbd27 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -15,14 +15,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+
import os
import unittest
from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torch, torch_requirement_message
-
+from pyspark.testing.connectutils import ReusedConnectTestCase
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin
@@ -33,18 +33,16 @@
or torch_requirement_message
or "Requires PySpark core library in Spark Connect server",
)
- class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = (
- SparkSession.builder.remote(
- os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
- )
- .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
- .getOrCreate()
- )
-
- def tearDown(self) -> None:
- self.spark.stop()
+ class PipelineTestsOnConnect(PipelineTestsMixin, ReusedConnectTestCase):
+ @classmethod
+ def conf(cls):
+ config = super(PipelineTestsOnConnect, cls).conf()
+ config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
+ return config
+
+ @classmethod
+ def master(cls):
+ return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_summarizer.py b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
index 9c737c96ee87a..57911779d6bb8 100644
--- a/python/pyspark/ml/tests/connect/test_connect_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
@@ -15,24 +15,20 @@
# limitations under the License.
#
-import unittest
import os
+import unittest
-from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+from pyspark.testing.connectutils import ReusedConnectTestCase
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_summarizer import SummarizerTestsMixin
@unittest.skipIf(not should_test_connect, connect_requirement_message)
- class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = SparkSession.builder.remote(
- os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
- ).getOrCreate()
-
- def tearDown(self) -> None:
- self.spark.stop()
+ class SummarizerTestsOnConnect(SummarizerTestsMixin, ReusedConnectTestCase):
+ @classmethod
+ def master(cls):
+ return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index d737dd5767dbd..3b7f977b57ae2 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -16,13 +16,13 @@
# limitations under the License.
#
-import unittest
import os
+import unittest
from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torch, torch_requirement_message
+from pyspark.testing.connectutils import ReusedConnectTestCase
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin
@@ -33,18 +33,16 @@
or torch_requirement_message
or "Requires PySpark core library in Spark Connect server",
)
- class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = (
- SparkSession.builder.remote(
- os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
- )
- .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
- .getOrCreate()
- )
-
- def tearDown(self) -> None:
- self.spark.stop()
+ class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, ReusedConnectTestCase):
+ @classmethod
+ def conf(cls):
+ config = super(CrossValidatorTestsOnConnect, cls).conf()
+ config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
+ return config
+
+ @classmethod
+ def master(cls):
+ return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
index fdae31077002e..aad668079ec0e 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
@@ -22,9 +22,9 @@
import numpy as np
from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torch, torch_requirement_message
+from pyspark.testing.sqlutils import ReusedSQLTestCase
if should_test_connect:
from pyspark.ml.connect.classification import (
@@ -231,12 +231,10 @@ def test_save_load(self):
or torch_requirement_message
or "pyspark-connect cannot test classic Spark",
)
-class ClassificationTests(ClassificationTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = SparkSession.builder.master("local[2]").getOrCreate()
-
- def tearDown(self) -> None:
- self.spark.stop()
+class ClassificationTests(ClassificationTestsMixin, ReusedSQLTestCase):
+ @classmethod
+ def master(cls):
+ return "local[2]"
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
index 3a5417dadf50a..df0ad8b23f8fc 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
@@ -21,9 +21,9 @@
import numpy as np
from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torcheval, torcheval_requirement_message
+from pyspark.testing.sqlutils import ReusedSQLTestCase
if should_test_connect:
from pyspark.ml.connect.evaluation import (
@@ -178,12 +178,10 @@ def test_multiclass_classifier_evaluator(self):
or torcheval_requirement_message
or "pyspark-connect cannot test classic Spark",
)
-class EvaluationTests(EvaluationTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = SparkSession.builder.master("local[2]").getOrCreate()
-
- def tearDown(self) -> None:
- self.spark.stop()
+class EvaluationTests(EvaluationTestsMixin, ReusedSQLTestCase):
+ @classmethod
+ def master(cls):
+ return "local[2]"
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
index 6812db778450a..2d0a37aca5c8c 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
@@ -24,9 +24,9 @@
import numpy as np
from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torch, torch_requirement_message
+from pyspark.testing.sqlutils import ReusedSQLTestCase
if should_test_connect:
from pyspark.ml.connect.feature import (
@@ -201,12 +201,10 @@ def test_array_assembler(self):
or torch_requirement_message
or "pyspark-connect cannot test classic Spark",
)
-class FeatureTests(FeatureTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = SparkSession.builder.master("local[2]").getOrCreate()
-
- def tearDown(self) -> None:
- self.spark.stop()
+class FeatureTests(FeatureTestsMixin, ReusedSQLTestCase):
+ @classmethod
+ def master(cls):
+ return "local[2]"
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
index 8b19f5931d207..125c05183f247 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
@@ -22,9 +22,9 @@
import numpy as np
from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torch, torch_requirement_message
+from pyspark.testing.sqlutils import ReusedSQLTestCase
if should_test_connect:
from pyspark.ml.connect.feature import StandardScaler
@@ -175,12 +175,10 @@ def test_pipeline_copy():
or torch_requirement_message
or "pyspark-connect cannot test classic Spark",
)
-class PipelineTests(PipelineTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = SparkSession.builder.master("local[2]").getOrCreate()
-
- def tearDown(self) -> None:
- self.spark.stop()
+class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
+ @classmethod
+ def master(cls):
+ return "local[2]"
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py b/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
index 253632a74c973..7cfc869e822c8 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
@@ -21,8 +21,8 @@
import numpy as np
from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+from pyspark.testing.sqlutils import ReusedSQLTestCase
if should_test_connect:
from pyspark.ml.connect.summarizer import summarize_dataframe
@@ -67,12 +67,10 @@ def assert_dict_allclose(dict1, dict2):
not should_test_connect or is_remote_only(),
connect_requirement_message or "pyspark-connect cannot test classic Spark",
)
-class SummarizerTests(SummarizerTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = SparkSession.builder.master("local[2]").getOrCreate()
-
- def tearDown(self) -> None:
- self.spark.stop()
+class SummarizerTests(SummarizerTestsMixin, ReusedSQLTestCase):
+ @classmethod
+ def master(cls):
+ return "local[2]"
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
index 06c3ad93d92d2..6cee852247ad0 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
@@ -25,7 +25,6 @@
from pyspark.util import is_remote_only
from pyspark.ml.param import Param, Params
from pyspark.ml.tuning import ParamGridBuilder
-from pyspark.sql import SparkSession
from pyspark.sql.functions import rand
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import (
@@ -36,6 +35,7 @@
have_torcheval,
torcheval_requirement_message,
)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
if should_test_connect:
import pandas as pd
@@ -294,12 +294,10 @@ def test_crossvalidator_with_fold_col(self):
or torcheval_requirement_message
or "pyspark-connect cannot test classic Spark",
)
-class CrossValidatorTests(CrossValidatorTestsMixin, unittest.TestCase):
- def setUp(self) -> None:
- self.spark = SparkSession.builder.master("local[2]").getOrCreate()
-
- def tearDown(self) -> None:
- self.spark.stop()
+class CrossValidatorTests(CrossValidatorTestsMixin, ReusedSQLTestCase):
+ @classmethod
+ def master(cls):
+ return "local[2]"
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
index 6c747be977736..fc0500b8e83a7 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
@@ -21,9 +21,10 @@
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
+from pyspark.testing import should_test_connect, connect_requirement_message
from pyspark.testing.utils import have_torch, torch_requirement_message
-if not is_remote_only():
+if not is_remote_only() and should_test_connect:
from pyspark.ml.torch.tests.test_distributor import (
TorchDistributorBaselineUnitTestsMixin,
TorchDistributorLocalUnitTestsMixin,
@@ -35,7 +36,8 @@
)
@unittest.skipIf(
- not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+ not should_test_connect or not have_torch or is_remote_only(),
+ connect_requirement_message or torch_requirement_message or "Requires JVM access",
)
class TorchDistributorBaselineUnitTestsOnConnect(
TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
@@ -142,7 +144,8 @@ def tearDownClass(cls):
cls.spark.stop()
@unittest.skipIf(
- not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+ not should_test_connect or not have_torch or is_remote_only(),
+ connect_requirement_message or torch_requirement_message or "Requires JVM access",
)
class TorchWrapperUnitTestsOnConnect(TorchWrapperUnitTestsMixin, unittest.TestCase):
pass
diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py
index d0e2600a9a8b3..0f5deab4e0935 100644
--- a/python/pyspark/ml/tests/test_algorithms.py
+++ b/python/pyspark/ml/tests/test_algorithms.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import os
from shutil import rmtree
import tempfile
import unittest
@@ -154,6 +154,28 @@ def test_support_for_weightCol(self):
ovr2 = OneVsRest(classifier=dt, weightCol="weight")
self.assertIsNotNone(ovr2.fit(df))
+ def test_tmp_dfs_cache(self):
+ from pyspark.ml.util import _SPARKML_TEMP_DFS_PATH
+
+ with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
+ os.environ[_SPARKML_TEMP_DFS_PATH] = d
+ try:
+ df = self.spark.createDataFrame(
+ [
+ (0.0, Vectors.dense(1.0, 0.8)),
+ (1.0, Vectors.sparse(2, [], [])),
+ (2.0, Vectors.dense(0.5, 0.5)),
+ ],
+ ["label", "features"],
+ )
+ lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ovr = OneVsRest(classifier=lr, parallelism=1)
+ model = ovr.fit(df)
+ model.transform(df)
+ assert len(os.listdir(d)) == 0
+ finally:
+ os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
+
class KMeansTests(SparkSessionTestCase):
def test_kmeans_cosine_distance(self):
diff --git a/python/pyspark/ml/tests/test_classification.py b/python/pyspark/ml/tests/test_classification.py
index 5c793dc344c7b..57e4c0ef86dc0 100644
--- a/python/pyspark/ml/tests/test_classification.py
+++ b/python/pyspark/ml/tests/test_classification.py
@@ -21,6 +21,7 @@
import numpy as np
+from pyspark.errors import PySparkException
from pyspark.ml.linalg import Vectors, Matrices
from pyspark.sql import DataFrame, Row
from pyspark.ml.classification import (
@@ -978,6 +979,10 @@ def test_mlp(self):
model2 = MultilayerPerceptronClassificationModel.load(d)
self.assertEqual(str(model), str(model2))
+ def test_invalid_load_location(self):
+ with self.assertRaisesRegex(PySparkException, "Path does not exist"):
+ LogisticRegression.load("invalid_location")
+
class ClassificationTests(ClassificationTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index f89c7305fc9ce..1b8eb73135a96 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -37,6 +37,7 @@
DistributedLDAModel,
PowerIterationClustering,
)
+from pyspark.sql import is_remote
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -377,6 +378,8 @@ def test_local_lda(self):
self.assertEqual(str(model), str(model2))
def test_distributed_lda(self):
+ if is_remote():
+ self.skipTest("Does not support Spark Connect.")
spark = self.spark
df = (
spark.createDataFrame(
@@ -404,6 +407,7 @@ def test_distributed_lda(self):
self.assertNotIsInstance(model, LocalLDAModel)
self.assertIsInstance(model, DistributedLDAModel)
self.assertTrue(model.isDistributed())
+ self.assertEqual(model.vocabSize(), 2)
dc = model.estimatedDocConcentration()
self.assertTrue(np.allclose(dc.toArray(), [26.0, 26.0], atol=1e-4), dc)
diff --git a/python/pyspark/ml/tests/test_fpm.py b/python/pyspark/ml/tests/test_fpm.py
index ea94216c98608..7b949763c3988 100644
--- a/python/pyspark/ml/tests/test_fpm.py
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -18,7 +18,7 @@
import tempfile
import unittest
-from pyspark.sql import Row
+from pyspark.sql import is_remote, Row
import pyspark.sql.functions as sf
from pyspark.ml.fpm import (
FPGrowth,
@@ -30,6 +30,8 @@
class FPMTestsMixin:
def test_fp_growth(self):
+ if is_remote():
+ self.skipTest("Does not support Spark Connect.")
df = self.spark.createDataFrame(
[
["r z h k p"],
diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py
index ced1cda1948a6..892ce72e32bc4 100644
--- a/python/pyspark/ml/tests/test_pipeline.py
+++ b/python/pyspark/ml/tests/test_pipeline.py
@@ -29,7 +29,7 @@
)
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
-from pyspark.ml.clustering import KMeans, KMeansModel
+from pyspark.ml.clustering import KMeans, KMeansModel, GaussianMixture
from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -176,7 +176,7 @@ def test_clustering_pipeline(self):
def test_model_gc(self):
spark = self.spark
- df = spark.createDataFrame(
+ df1 = spark.createDataFrame(
[
Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
@@ -189,8 +189,107 @@ def fit_transform(df):
model = lr.fit(df)
return model.transform(df)
- output = fit_transform(df)
- self.assertEqual(output.count(), 3)
+ output1 = fit_transform(df1)
+ self.assertEqual(output1.count(), 3)
+
+ df2 = spark.range(10)
+
+ def fit_transform_and_union(df1, df2):
+ output1 = fit_transform(df1)
+ return output1.unionByName(df2, True)
+
+ output2 = fit_transform_and_union(df1, df2)
+ self.assertEqual(output2.count(), 13)
+
+ def test_model_training_summary_gc(self):
+ spark = self.spark
+ df1 = spark.createDataFrame(
+ [
+ Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
+ Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
+ Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
+ ]
+ )
+
+ def fit_predictions(df):
+ lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
+ model = lr.fit(df)
+ return model.summary.predictions
+
+ output1 = fit_predictions(df1)
+ self.assertEqual(output1.count(), 3)
+
+ df2 = spark.range(10)
+
+ def fit_predictions_and_union(df1, df2):
+ output1 = fit_predictions(df1)
+ return output1.unionByName(df2, True)
+
+ output2 = fit_predictions_and_union(df1, df2)
+ self.assertEqual(output2.count(), 13)
+
+ def test_model_testing_summary_gc(self):
+ spark = self.spark
+ df1 = spark.createDataFrame(
+ [
+ Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
+ Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
+ Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
+ ]
+ )
+
+ def fit_predictions(df):
+ lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
+ model = lr.fit(df)
+ return model.evaluate(df).predictions
+
+ output1 = fit_predictions(df1)
+ self.assertEqual(output1.count(), 3)
+
+ df2 = spark.range(10)
+
+ def fit_predictions_and_union(df1, df2):
+ output1 = fit_predictions(df1)
+ return output1.unionByName(df2, True)
+
+ output2 = fit_predictions_and_union(df1, df2)
+ self.assertEqual(output2.count(), 13)
+
+ def test_model_attr_df_gc(self):
+ spark = self.spark
+ df1 = (
+ spark.createDataFrame(
+ [
+ (1, 1.0, Vectors.dense([-0.1, -0.05])),
+ (2, 2.0, Vectors.dense([-0.01, -0.1])),
+ (3, 3.0, Vectors.dense([0.9, 0.8])),
+ (4, 1.0, Vectors.dense([0.75, 0.935])),
+ (5, 1.0, Vectors.dense([-0.83, -0.68])),
+ (6, 1.0, Vectors.dense([-0.91, -0.76])),
+ ],
+ ["index", "weight", "features"],
+ )
+ .coalesce(1)
+ .sortWithinPartitions("index")
+ .select("weight", "features")
+ )
+
+ def fit_attr_df(df):
+ gmm = GaussianMixture(k=2, maxIter=2, weightCol="weight", seed=1)
+ model = gmm.fit(df)
+ return model.gaussiansDF
+
+ output1 = fit_attr_df(df1)
+ self.assertEqual(output1.count(), 2)
+
+ df2 = spark.range(10)
+
+ def fit_attr_df_and_union(df1, df2):
+ output1 = fit_attr_df(df1)
+ return output1.unionByName(df2, True)
+
+ output2 = fit_attr_df_and_union(df1, df2)
+ self.assertEqual(output2.count(), 12)
class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index 94081a090bfe4..ff9a26f711975 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import os
import tempfile
import unittest
@@ -22,7 +22,7 @@
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
-from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.classification import LogisticRegression, RandomForestClassifier
from pyspark.ml.tuning import (
ParamGridBuilder,
CrossValidator,
@@ -30,6 +30,7 @@
TrainValidationSplit,
TrainValidationSplitModel,
)
+from pyspark.ml.util import _SPARKML_TEMP_DFS_PATH
from pyspark.sql.functions import rand
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -70,6 +71,18 @@ def test_train_validation_split(self):
evaluation_score = evaluator.evaluate(tvs_model.transform(dataset))
self.assertTrue(np.isclose(evaluation_score, 0.8333333333333333, atol=1e-4))
+ with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
+ os.environ[_SPARKML_TEMP_DFS_PATH] = d
+ try:
+ tvs_model2 = tvs.fit(dataset)
+ assert len(os.listdir(d)) == 0
+ self.assertTrue(np.isclose(tvs_model2.validationMetrics[0], 0.5, atol=1e-4))
+ self.assertTrue(
+ np.isclose(tvs_model2.validationMetrics[1], 0.8857142857142857, atol=1e-4)
+ )
+ finally:
+ os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
+
# save & load
with tempfile.TemporaryDirectory(prefix="train_validation_split") as d:
tvs.write().overwrite().save(d)
@@ -118,6 +131,15 @@ def test_cross_validator(self):
self.assertEqual(model.getEstimatorParamMaps(), grid)
self.assertTrue(np.isclose(model.avgMetrics[0], 0.5, atol=1e-4))
+ with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
+ os.environ[_SPARKML_TEMP_DFS_PATH] = d
+ try:
+ model2 = cv.fit(dataset)
+ assert len(os.listdir(d)) == 0
+ self.assertTrue(np.isclose(model2.avgMetrics[0], 0.5, atol=1e-4))
+ finally:
+ os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
+
output = model.transform(dataset)
self.assertEqual(
output.columns, ["features", "label", "rawPrediction", "probability", "prediction"]
@@ -230,6 +252,28 @@ def test_cv_invalid_user_specified_folds(self):
with self.assertRaisesRegex(Exception, "The validation data at fold 3 is empty"):
cv.fit(dataset_with_folds)
+ def test_crossvalidator_with_random_forest_classifier(self):
+ dataset = self.spark.createDataFrame(
+ [
+ (Vectors.dense(1.0, 2.0), 0),
+ (Vectors.dense(2.0, 3.0), 1),
+ (Vectors.dense(1.5, 2.5), 0),
+ (Vectors.dense(3.0, 4.0), 1),
+ (Vectors.dense(1.1, 2.1), 0),
+ (Vectors.dense(2.5, 3.5), 1),
+ ],
+ ["features", "label"],
+ )
+ rf = RandomForestClassifier(labelCol="label", featuresCol="features")
+ evaluator = BinaryClassificationEvaluator(labelCol="label")
+ paramGrid = (
+ ParamGridBuilder().addGrid(rf.maxDepth, [2]).addGrid(rf.numTrees, [5, 10]).build()
+ )
+ cv = CrossValidator(
+ estimator=rf, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3
+ )
+ cv.fit(dataset)
+
class TuningTests(TuningTestsMixin, ReusedSQLTestCase):
pass
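The new assertions above exercise the SPARKML_TEMP_DFS_PATH switch introduced by the tuning and util changes below: when it points at a directory, fold splits are materialized there instead of being persisted, and the files are removed again before fit() returns. A minimal sketch of the toggle, with cv and dataset as placeholders for the objects built in the test:

    import os
    import tempfile
    from pyspark.ml.util import _SPARKML_TEMP_DFS_PATH

    with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
        os.environ[_SPARKML_TEMP_DFS_PATH] = d      # route intermediate splits to this directory
        try:
            model = cv.fit(dataset)                 # cv / dataset: placeholders from the test above
            assert len(os.listdir(d)) == 0          # temp data is cleaned up before fit() returns
        finally:
            os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)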
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index a9892dc8db36d..498195baf1eb7 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -55,11 +55,10 @@
JavaMLWriter,
try_remote_write,
try_remote_read,
- is_remote,
+ _cache_spark_dataset,
)
from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
-from pyspark.sql.functions import col, lit, rand
-from pyspark.sql.types import BooleanType
+from pyspark.sql import functions as F
from pyspark.sql.dataframe import DataFrame
if TYPE_CHECKING:
@@ -849,22 +848,23 @@ def _fit(self, dataset: DataFrame) -> "CrossValidatorModel":
subModels = [[None for j in range(numModels)] for i in range(nFolds)]
datasets = self._kFold(dataset)
- for i in range(nFolds):
- validation = datasets[i][1].cache()
- train = datasets[i][0].cache()
-
- tasks = map(
- inheritable_thread_target(dataset.sparkSession),
- _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
- )
- for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
- metrics_all[i][j] = metric
- if collectSubModelsParam:
- assert subModels is not None
- subModels[i][j] = subModel
- validation.unpersist()
- train.unpersist()
+ for i in range(nFolds):
+ validation = datasets[i][1]
+ train = datasets[i][0]
+
+ with _cache_spark_dataset(train) as train, _cache_spark_dataset(
+ validation
+ ) as validation:
+ tasks = map(
+ inheritable_thread_target(dataset.sparkSession),
+ _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
+ )
+ for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
+ metrics_all[i][j] = metric
+ if collectSubModelsParam:
+ assert subModels is not None
+ subModels[i][j] = subModel
metrics, std_metrics = CrossValidator._gen_avg_and_std_metrics(metrics_all)
@@ -887,7 +887,7 @@ def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
seed = self.getOrDefault(self.seed)
h = 1.0 / nFolds
randCol = self.uid + "_rand"
- df = dataset.select("*", rand(seed).alias(randCol))
+ df = dataset.select("*", F.rand(seed).alias(randCol))
for i in range(nFolds):
validateLB = i * h
validateUB = (i + 1) * h
@@ -897,38 +897,27 @@ def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
datasets.append((train, validation))
else:
# Use user-specified fold numbers.
- def checker(foldNum: int) -> bool:
- if foldNum < 0 or foldNum >= nFolds:
- raise ValueError(
- "Fold number must be in range [0, %s), but got %s." % (nFolds, foldNum)
- )
- return True
-
- if is_remote():
- from pyspark.sql.connect.udf import UserDefinedFunction
- else:
- from pyspark.sql.functions import UserDefinedFunction # type: ignore[assignment]
- from pyspark.util import PythonEvalType
-
- # TODO(SPARK-48515): Use Arrow Python UDF
- checker_udf = UserDefinedFunction(
- checker, BooleanType(), evalType=PythonEvalType.SQL_BATCHED_UDF
+ checked = dataset.withColumn(
+ foldCol,
+ F.when(
+ (F.lit(0) <= F.col(foldCol)) & (F.col(foldCol) < F.lit(nFolds)), F.col(foldCol)
+ ).otherwise(
+ F.raise_error(
+ F.printf(
+ F.lit(f"Fold number must be in range [0, {nFolds}), but got %s"),
+ F.col(foldCol),
+ )
+ ),
+ ),
)
+
for i in range(nFolds):
- training = dataset.filter(checker_udf(dataset[foldCol]) & (col(foldCol) != lit(i)))
- validation = dataset.filter(
- checker_udf(dataset[foldCol]) & (col(foldCol) == lit(i))
- )
- if is_remote():
- if len(training.take(1)) == 0:
- raise ValueError("The training data at fold %s is empty." % i)
- if len(validation.take(1)) == 0:
- raise ValueError("The validation data at fold %s is empty." % i)
- else:
- if training.rdd.getNumPartitions() == 0 or len(training.take(1)) == 0:
- raise ValueError("The training data at fold %s is empty." % i)
- if validation.rdd.getNumPartitions() == 0 or len(validation.take(1)) == 0:
- raise ValueError("The validation data at fold %s is empty." % i)
+ training = checked.filter(F.col(foldCol) != F.lit(i))
+ validation = checked.filter(F.col(foldCol) == F.lit(i))
+ if training.isEmpty():
+ raise ValueError("The training data at fold %s is empty." % i)
+ if validation.isEmpty():
+ raise ValueError("The validation data at fold %s is empty." % i)
datasets.append((training, validation))
return datasets
@@ -1486,30 +1475,29 @@ def _fit(self, dataset: DataFrame) -> "TrainValidationSplitModel":
tRatio = self.getOrDefault(self.trainRatio)
seed = self.getOrDefault(self.seed)
randCol = self.uid + "_rand"
- df = dataset.select("*", rand(seed).alias(randCol))
+ df = dataset.select("*", F.rand(seed).alias(randCol))
condition = df[randCol] >= tRatio
- validation = df.filter(condition).cache()
- train = df.filter(~condition).cache()
- subModels = None
- collectSubModelsParam = self.getCollectSubModels()
- if collectSubModelsParam:
- subModels = [None for i in range(numModels)]
+ validation = df.filter(condition)
+ train = df.filter(~condition)
- tasks = map(
- inheritable_thread_target(dataset.sparkSession),
- _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
- )
- pool = ThreadPool(processes=min(self.getParallelism(), numModels))
- metrics = [None] * numModels
- for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
- metrics[j] = metric
+ with _cache_spark_dataset(train) as train, _cache_spark_dataset(validation) as validation:
+ subModels = None
+ collectSubModelsParam = self.getCollectSubModels()
if collectSubModelsParam:
- assert subModels is not None
- subModels[j] = subModel
+ subModels = [None for i in range(numModels)]
- train.unpersist()
- validation.unpersist()
+ tasks = map(
+ inheritable_thread_target(dataset.sparkSession),
+ _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
+ )
+ pool = ThreadPool(processes=min(self.getParallelism(), numModels))
+ metrics = [None] * numModels
+ for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
+ metrics[j] = metric
+ if collectSubModelsParam:
+ assert subModels is not None
+ subModels[j] = subModel
if eva.isLargerBetter():
bestIndex = np.argmax(cast(List[float], metrics))
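The _kFold rewrite above replaces the Python UDF fold checker with a pure Column expression: out-of-range values in the user-supplied fold column now fail inside the query via F.raise_error, and empty folds are detected with DataFrame.isEmpty(). A self-contained sketch of the user-specified-folds path, assuming an active SparkSession named spark:

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    # Illustrative dataset carrying its own fold assignment; values must lie in [0, numFolds).
    dataset = spark.createDataFrame(
        [(Vectors.dense([float(i % 2)]), float(i % 2), i % 3) for i in range(12)],
        ["features", "label", "fold"],
    )
    cv = CrossValidator(
        estimator=LogisticRegression(maxIter=2),
        estimatorParamMaps=ParamGridBuilder().build(),
        evaluator=BinaryClassificationEvaluator(),
        numFolds=3,
        foldCol="fold",
    )
    # A fold value outside [0, 3) would now raise from within the query plan,
    # not from a Python UDF evaluated row by row.
    model = cv.fit(dataset)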
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6b3d6101c249f..a5e0c847c1732 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -17,6 +17,7 @@
import json
import os
+import threading
import time
import uuid
import functools
@@ -25,6 +26,7 @@
Callable,
Dict,
Generic,
+ Iterator,
List,
Optional,
Sequence,
@@ -34,11 +36,13 @@
TYPE_CHECKING,
Union,
)
+from contextlib import contextmanager
from pyspark import since
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
from pyspark.sql.utils import is_remote
+from pyspark.storagelevel import StorageLevel
from pyspark.util import VersionUtils
if TYPE_CHECKING:
@@ -47,6 +51,7 @@
from pyspark.ml.base import Params
from pyspark.ml.wrapper import JavaWrapper
from pyspark.core.context import SparkContext
+ from pyspark.sql import DataFrame
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.ml.wrapper import JavaWrapper, JavaEstimator
from pyspark.ml.evaluation import JavaEvaluator
@@ -71,7 +76,7 @@ def try_remote_intermediate_result(f: FuncT) -> FuncT:
@functools.wraps(f)
def wrapped(self: "JavaWrapper") -> Any:
if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
- return f"{self._java_obj}.{f.__name__}"
+ return f"{str(self._java_obj)}.{f.__name__}"
else:
return f(self)
@@ -104,15 +109,25 @@ def invoke_remote_attribute_relation(
from pyspark.ml.connect.proto import AttributeRelation
from pyspark.sql.connect.session import SparkSession
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+ from pyspark.ml.wrapper import JavaModel
session = SparkSession.getActiveSession()
assert session is not None
- assert isinstance(instance._java_obj, str)
-
- methods, obj_ref = _extract_id_methods(instance._java_obj)
+ if isinstance(instance, JavaModel):
+ assert isinstance(instance._java_obj, RemoteModelRef)
+ object_id = instance._java_obj.ref_id
+ else:
+ # model summary
+ object_id = instance._java_obj # type: ignore
+ methods, obj_ref = _extract_id_methods(object_id)
methods.append(pb2.Fetch.Method(method=method, args=serialize(session.client, *args)))
plan = AttributeRelation(obj_ref, methods)
+
+ # To delay the GC of the model, keep a reference to the source instance,
+ # which might be a model or a summary.
+ plan.__source_instance__ = instance # type: ignore[attr-defined]
+
return ConnectDataFrame(plan, session)
@@ -130,6 +145,33 @@ def wrapped(self: "JavaWrapper", *args: Any, **kwargs: Any) -> Any:
return cast(FuncT, wrapped)
+class RemoteModelRef:
+ def __init__(self, ref_id: str) -> None:
+ self._ref_id = ref_id
+ self._ref_count = 1
+ self._lock = threading.Lock()
+
+ @property
+ def ref_id(self) -> str:
+ return self._ref_id
+
+ def add_ref(self) -> None:
+ with self._lock:
+ assert self._ref_count > 0
+ self._ref_count += 1
+
+ def release_ref(self) -> None:
+ with self._lock:
+ assert self._ref_count > 0
+ self._ref_count -= 1
+ if self._ref_count == 0:
+ # Delete the model if possible
+ del_remote_cache(self.ref_id)
+
+ def __str__(self) -> str:
+ return self.ref_id
+
+
def try_remote_fit(f: FuncT) -> FuncT:
"""Mark the function that fits a model."""
@@ -156,7 +198,8 @@ def wrapped(self: "JavaEstimator", dataset: "ConnectDataFrame") -> Any:
(_, properties, _) = client.execute_command(command)
model_info = deserialize(properties)
client.add_ml_cache(model_info.obj_ref.id)
- model = self._create_model(model_info.obj_ref.id)
+ remote_model_ref = RemoteModelRef(model_info.obj_ref.id)
+ model = self._create_model(remote_model_ref)
if model.__class__.__name__ not in ["Bucketizer"]:
model._resetUid(self.uid)
return self._copyValues(model)
@@ -183,31 +226,42 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
if isinstance(self, Model):
from pyspark.ml.connect.proto import TransformerRelation
- assert isinstance(self._java_obj, str)
+ assert isinstance(self._java_obj, RemoteModelRef)
params = serialize_ml_params(self, session.client)
- return ConnectDataFrame(
- TransformerRelation(
- child=dataset._plan, name=self._java_obj, ml_params=params, is_model=True
- ),
- session,
+ plan = TransformerRelation(
+ child=dataset._plan,
+ name=self._java_obj.ref_id,
+ ml_params=params,
+ is_model=True,
)
elif isinstance(self, Transformer):
from pyspark.ml.connect.proto import TransformerRelation
assert isinstance(self._java_obj, str)
params = serialize_ml_params(self, session.client)
- return ConnectDataFrame(
- TransformerRelation(
- child=dataset._plan,
- name=self._java_obj,
- ml_params=params,
- uid=self.uid,
- is_model=False,
- ),
- session,
+ plan = TransformerRelation(
+ child=dataset._plan,
+ name=self._java_obj,
+ ml_params=params,
+ uid=self.uid,
+ is_model=False,
)
+
else:
raise RuntimeError(f"Unsupported {self}")
+
+ # To delay the GC of the model, keep a reference to the source transformer
+ # in the transformed dataframe and all its descendants.
+ # For this case:
+ #
+ # def fit_transform(df):
+ # model = estimator.fit(df)
+ # return model.transform(df)
+ #
+ # output = fit_transform(df)
+ #
+ plan.__source_transformer__ = self # type: ignore[attr-defined]
+ return ConnectDataFrame(plan=plan, session=session)
else:
return f(self, dataset)
@@ -226,11 +280,20 @@ def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any:
from pyspark.sql.connect.session import SparkSession
from pyspark.ml.connect.util import _extract_id_methods
from pyspark.ml.connect.serialize import serialize, deserialize
+ from pyspark.ml.wrapper import JavaModel
session = SparkSession.getActiveSession()
assert session is not None
- assert isinstance(self._java_obj, str)
- methods, obj_ref = _extract_id_methods(self._java_obj)
+ if self._java_obj == ML_CONNECT_HELPER_ID:
+ obj_id = ML_CONNECT_HELPER_ID
+ else:
+ if isinstance(self, JavaModel):
+ assert isinstance(self._java_obj, RemoteModelRef)
+ obj_id = self._java_obj.ref_id
+ else:
+ # model summary
+ obj_id = self._java_obj # type: ignore
+ methods, obj_ref = _extract_id_methods(obj_id)
methods.append(pb2.Fetch.Method(method=name, args=serialize(session.client, *args)))
command = pb2.Command()
command.ml_command.fetch.CopyFrom(
@@ -281,20 +344,8 @@ def wrapped(self: "JavaWrapper") -> Any:
except Exception:
return
- if in_remote:
- # Delete the model if possible
- # model_id = self._java_obj
- # del_remote_cache(model_id)
- #
- # Above codes delete the model from the ml cache eagerly, and may cause
- # NPE in the server side in the case of 'fit_transform':
- #
- # def fit_transform(df):
- # model = estimator.fit(df)
- # return model.transform(df)
- #
- # output = fit_transform(df)
- # output.show()
+ if in_remote and isinstance(self._java_obj, RemoteModelRef):
+ self._java_obj.release_ref()
return
else:
return f(self)
@@ -1104,3 +1155,53 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
return f(*args, **kwargs)
return cast(FuncT, wrapped)
+
+
+_SPARKML_TEMP_DFS_PATH = "SPARKML_TEMP_DFS_PATH"
+
+
+def _get_temp_dfs_path() -> Optional[str]:
+ return os.environ.get(_SPARKML_TEMP_DFS_PATH)
+
+
+def _remove_dfs_dir(path: str, spark_session: "SparkSession") -> None:
+ from pyspark.ml.wrapper import JavaWrapper
+ from pyspark.sql import is_remote
+
+ if is_remote():
+ from pyspark.ml.util import ML_CONNECT_HELPER_ID
+
+ helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+ helper._call_java("handleOverwrite", path, True)
+ else:
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite")
+ wrapper = JavaWrapper(_java_obj)
+ wrapper._call_java("handleOverwrite", path, True, spark_session._jsparkSession)
+
+
+@contextmanager
+def _cache_spark_dataset(
+ dataset: "DataFrame",
+ storageLevel: "StorageLevel" = StorageLevel.MEMORY_AND_DISK_DESER,
+ enable: bool = True,
+) -> Iterator[Any]:
+ if not enable:
+ yield dataset
+ return
+
+ spark_session = dataset._session
+ tmp_dfs_path = os.environ.get(_SPARKML_TEMP_DFS_PATH)
+
+ if tmp_dfs_path:
+ tmp_cache_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)
+ dataset.write.save(tmp_cache_path)
+ try:
+ yield spark_session.read.load(tmp_cache_path)
+ finally:
+ _remove_dfs_dir(tmp_cache_path, spark_session)
+ else:
+ dataset.persist(storageLevel)
+ try:
+ yield dataset
+ finally:
+ dataset.unpersist()
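For reference, the helper above is what the reworked CrossValidator and TrainValidationSplit consume: with SPARKML_TEMP_DFS_PATH set it round-trips the dataset through a temporary DFS directory (removed on exit), otherwise it falls back to persist()/unpersist(). A minimal usage sketch, with train and estimator as placeholders:

    from pyspark.ml.util import _cache_spark_dataset

    with _cache_spark_dataset(train) as cached_train:
        # cached_train is either a re-read copy of the temp DFS files or the
        # persisted original; either way it is released when the block exits.
        model = estimator.fit(cached_train)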
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index f88045e718a55..b8d86e9eab3b1 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -356,9 +356,15 @@ def copy(self: "JP", extra: Optional["ParamMap"] = None) -> "JP":
if extra is None:
extra = dict()
that = super(JavaParams, self).copy(extra)
- if self._java_obj is not None and not isinstance(self._java_obj, str):
- that._java_obj = self._java_obj.copy(self._empty_java_param_map())
- that._transfer_params_to_java()
+ if self._java_obj is not None:
+ from pyspark.ml.util import RemoteModelRef
+
+ if isinstance(self._java_obj, RemoteModelRef):
+ that._java_obj = self._java_obj
+ self._java_obj.add_ref()
+ elif not isinstance(self._java_obj, str):
+ that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+ that._transfer_params_to_java()
return that
@try_remote_intercept
@@ -452,6 +458,10 @@ def __init__(self, java_model: Optional["JavaObject"] = None):
other ML classes).
"""
super(JavaModel, self).__init__(java_model)
+ if is_remote() and java_model is not None:
+ from pyspark.ml.util import RemoteModelRef
+
+ assert isinstance(java_model, RemoteModelRef)
if java_model is not None and not is_remote():
# SPARK-10931: This is a temporary fix to allow models to own params
# from estimators. Eventually, these params should be in models through
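The copy() change above is the second half of the ref-counting scheme: fit() creates a RemoteModelRef with count 1, each copy shares the handle and bumps the count, and each JavaModel.__del__ releases one reference, so the server-side cache entry is dropped only when the last holder goes away. An illustrative walk-through, with a placeholder id:

    from pyspark.ml.util import RemoteModelRef

    ref = RemoteModelRef("illustrative-model-id")  # count = 1, as set right after fit()
    ref.add_ref()      # JavaParams.copy() shares the same handle: count = 2
    ref.release_ref()  # one JavaModel.__del__ brings it back to 1
    # A final release_ref() would reach 0 and call del_remote_cache(ref_id),
    # which previously happened eagerly and could break fit_transform-style chains.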
diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py
index 81358185a5608..01e23214d662d 100644
--- a/python/pyspark/pandas/base.py
+++ b/python/pyspark/pandas/base.py
@@ -1191,7 +1191,6 @@ def _shift(
return self._with_new_scol(col, field=self._internal.data_fields[0].copy(nullable=True))
# TODO: Update Documentation for Bins Parameter when its supported
- # TODO(SPARK-51287): Enable s.index.value_counts() tests
def value_counts(
self,
normalize: bool = False,
@@ -1324,7 +1323,7 @@ def value_counts(
('falcon', 'length')],
)
- >>> s.index.value_counts().sort_index() # doctest: +SKIP
+ >>> s.index.value_counts().sort_index()
(cow, length) 1
(cow, weight) 2
(falcon, length) 2
@@ -1332,7 +1331,7 @@ def value_counts(
(lama, weight) 3
Name: count, dtype: int64
- >>> s.index.value_counts(normalize=True).sort_index() # doctest: +SKIP
+ >>> s.index.value_counts(normalize=True).sort_index()
(cow, length) 0.111111
(cow, weight) 0.222222
(falcon, length) 0.222222
diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py
index d31bc1f48d112..2ab1260eff695 100644
--- a/python/pyspark/pandas/namespace.py
+++ b/python/pyspark/pandas/namespace.py
@@ -36,6 +36,7 @@
from datetime import tzinfo
from functools import reduce
from io import BytesIO
+import pickle
import json
import warnings
@@ -65,6 +66,7 @@
StringType,
DateType,
StructType,
+ StructField,
DataType,
)
from pyspark.sql.dataframe import DataFrame as PySparkDataFrame
@@ -82,9 +84,11 @@
validate_axis,
log_advice,
)
+from pyspark.pandas.config import get_option
from pyspark.pandas.frame import DataFrame, _reduce_spark_multi
from pyspark.pandas.internal import (
InternalFrame,
+ InternalField,
DEFAULT_SERIES_NAME,
HIDDEN_COLUMNS,
SPARK_INDEX_NAME_FORMAT,
@@ -1128,7 +1132,9 @@ def read_excel(
"""
def pd_read_excel(
- io_or_bin: Any, sn: Union[str, int, List[Union[str, int]], None]
+ io_or_bin: Any,
+ sn: Union[str, int, List[Union[str, int]], None],
+ nr: Optional[int] = None,
) -> pd.DataFrame:
return pd.read_excel(
io=BytesIO(io_or_bin) if isinstance(io_or_bin, (bytes, bytearray)) else io_or_bin,
@@ -1143,7 +1149,7 @@ def pd_read_excel(
true_values=true_values,
false_values=false_values,
skiprows=skiprows,
- nrows=nrows,
+ nrows=nr,
na_values=na_values,
keep_default_na=keep_default_na,
verbose=verbose,
@@ -1155,18 +1161,9 @@ def pd_read_excel(
**kwds,
)
- if isinstance(io, str):
- # 'binaryFile' format is available since Spark 3.0.0.
- binaries = default_session().read.format("binaryFile").load(io).select("content").head(2)
- io_or_bin = binaries[0][0]
- single_file = len(binaries) == 1
- else:
- io_or_bin = io
- single_file = True
-
- pdf_or_psers = pd_read_excel(io_or_bin, sn=sheet_name)
-
- if single_file:
+ if not isinstance(io, str):
+ # When io is not a path, we always need to load all the data on the Python side
+ pdf_or_psers = pd_read_excel(io, sn=sheet_name, nr=nrows)
if isinstance(pdf_or_psers, dict):
return {
sn: cast(Union[DataFrame, Series], from_pandas(pdf_or_pser))
@@ -1174,52 +1171,89 @@ def pd_read_excel(
}
else:
return cast(Union[DataFrame, Series], from_pandas(pdf_or_psers))
- else:
-
- def read_excel_on_spark(
- pdf_or_pser: Union[pd.DataFrame, pd.Series],
- sn: Union[str, int, List[Union[str, int]], None],
- ) -> Union[DataFrame, Series]:
- if isinstance(pdf_or_pser, pd.Series):
- pdf = pdf_or_pser.to_frame()
- else:
- pdf = pdf_or_pser
- psdf = cast(DataFrame, from_pandas(pdf))
- return_schema = force_decimal_precision_scale(
- as_nullable_spark_type(psdf._internal.spark_frame.drop(*HIDDEN_COLUMNS).schema)
- )
+ spark = default_session()
+
+ # Collect at most `nr` rows from the first file
+ nr = get_option("compute.max_rows", 1000)
+ if nrows is not None and nrows < nr:
+ nr = nrows
+
+ def sample_data(pdf: pd.DataFrame) -> pd.DataFrame:
+ raw_data = BytesIO(pdf.content[0])
+ pdf_or_dict = pd_read_excel(raw_data, sn=sheet_name, nr=nr)
+ return pd.DataFrame({"sampled": [pickle.dumps(pdf_or_dict)]})
+
+ # 'binaryFile' format is available since Spark 3.0.0.
+ sampled = (
+ spark.read.format("binaryFile")
+ .load(io)
+ .select("content")
+ .limit(1) # Read at most 1 file
+ .mapInPandas(func=lambda iterator: map(sample_data, iterator), schema="sampled BINARY")
+ .head()
+ )
+ sampled = pickle.loads(sampled[0])
+
+ def read_excel_on_spark(
+ pdf_or_pser: Union[pd.DataFrame, pd.Series],
+ sn: Union[str, int, List[Union[str, int]], None],
+ ) -> Union[DataFrame, Series]:
+ if isinstance(pdf_or_pser, pd.Series):
+ pdf = pdf_or_pser.to_frame()
+ else:
+ pdf = pdf_or_pser
- def output_func(pdf: pd.DataFrame) -> pd.DataFrame:
- pdf = pd.concat([pd_read_excel(bin, sn=sn) for bin in pdf[pdf.columns[0]]])
-
- reset_index = pdf.reset_index()
- for name, col in reset_index.items():
- dt = col.dtype
- if is_datetime64_dtype(dt) or isinstance(dt, pd.DatetimeTZDtype):
- continue
- reset_index[name] = col.replace({np.nan: None})
- pdf = reset_index
-
- # Just positionally map the column names to given schema's.
- return pdf.rename(columns=dict(zip(pdf.columns, return_schema.names)))
-
- sdf = (
- default_session()
- .read.format("binaryFile")
- .load(io)
- .select("content")
- .mapInPandas(lambda iterator: map(output_func, iterator), schema=return_schema)
- )
+ psdf = cast(DataFrame, from_pandas(pdf))
- return DataFrame(psdf._internal.with_new_sdf(sdf))
+ raw_schema = psdf._internal.spark_frame.drop(*HIDDEN_COLUMNS).schema
+ index_scol_names = psdf._internal.index_spark_column_names
+ nullable_fields = []
+ for field in raw_schema.fields:
+ if field.name in index_scol_names:
+ nullable_fields.append(field)
+ else:
+ nullable_fields.append(
+ StructField(
+ field.name,
+ as_nullable_spark_type(field.dataType),
+ nullable=True,
+ metadata=field.metadata,
+ )
+ )
+ nullable_schema = StructType(nullable_fields)
+ return_schema = force_decimal_precision_scale(nullable_schema)
+
+ return_data_fields: Optional[List[InternalField]] = None
+ if psdf._internal.data_fields is not None:
+ return_data_fields = [f.normalize_spark_type() for f in psdf._internal.data_fields]
+
+ def output_func(pdf: pd.DataFrame) -> pd.DataFrame:
+ pdf = pd.concat([pd_read_excel(bin, sn=sn, nr=nrows) for bin in pdf[pdf.columns[0]]])
+
+ reset_index = pdf.reset_index()
+ for name, col in reset_index.items():
+ dt = col.dtype
+ if is_datetime64_dtype(dt) or isinstance(dt, pd.DatetimeTZDtype):
+ continue
+ reset_index[name] = col.replace({np.nan: None})
+ pdf = reset_index
+
+ # Just positionally map the column names to given schema's.
+ return pdf.rename(columns=dict(zip(pdf.columns, return_schema.names)))
+
+ sdf = (
+ spark.read.format("binaryFile")
+ .load(io)
+ .select("content")
+ .mapInPandas(lambda iterator: map(output_func, iterator), schema=return_schema)
+ )
+ return DataFrame(psdf._internal.with_new_sdf(sdf, data_fields=return_data_fields))
- if isinstance(pdf_or_psers, dict):
- return {
- sn: read_excel_on_spark(pdf_or_pser, sn) for sn, pdf_or_pser in pdf_or_psers.items()
- }
- else:
- return read_excel_on_spark(pdf_or_psers, sheet_name)
+ if isinstance(sampled, dict):
+ return {sn: read_excel_on_spark(pdf_or_pser, sn) for sn, pdf_or_pser in sampled.items()}
+ else:
+ return read_excel_on_spark(cast(Union[pd.DataFrame, pd.Series], sampled), sheet_name)
def read_html(
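The read_excel rework above no longer collects a whole file to the driver to infer the schema: it samples at most compute.max_rows rows (or nrows, if smaller) from the first file inside mapInPandas, then parses every file in parallel against that schema with nrows pushed down into pd.read_excel. A minimal usage sketch, with an illustrative path and openpyxl assumed installed:

    import pyspark.pandas as ps

    # Reads at most 10 rows per Excel file; only the first file is sampled for schema inference.
    psdf = ps.read_excel("/tmp/excel_dir", nrows=10)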
diff --git a/python/pyspark/pandas/tests/io/test_dataframe_spark_io.py b/python/pyspark/pandas/tests/io/test_dataframe_spark_io.py
index b8225b10f1c79..af77ea8aa64ff 100644
--- a/python/pyspark/pandas/tests/io/test_dataframe_spark_io.py
+++ b/python/pyspark/pandas/tests/io/test_dataframe_spark_io.py
@@ -23,6 +23,7 @@
from pyspark import pandas as ps
from pyspark.loose_version import LooseVersion
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.utils import have_openpyxl, openpyxl_requirement_message
class DataFrameSparkIOTestsMixin:
@@ -253,8 +254,7 @@ def test_spark_io(self):
expected_idx.sort_values(by="f").to_spark().toPandas(),
)
- # TODO(SPARK-40353): re-enabling the `test_read_excel`.
- @unittest.skip("openpyxl")
+ @unittest.skipIf(not have_openpyxl, openpyxl_requirement_message)
def test_read_excel(self):
with self.temp_dir() as tmp:
path1 = "{}/file1.xlsx".format(tmp)
@@ -266,21 +266,22 @@ def test_read_excel(self):
pd.read_excel(open(path1, "rb"), index_col=0),
)
self.assert_eq(
- ps.read_excel(open(path1, "rb"), index_col=0, squeeze=True),
- pd.read_excel(open(path1, "rb"), index_col=0, squeeze=True),
+ ps.read_excel(open(path1, "rb"), index_col=0),
+ pd.read_excel(open(path1, "rb"), index_col=0),
)
self.assert_eq(ps.read_excel(path1), pd.read_excel(path1))
self.assert_eq(ps.read_excel(path1, index_col=0), pd.read_excel(path1, index_col=0))
self.assert_eq(
- ps.read_excel(path1, index_col=0, squeeze=True),
- pd.read_excel(path1, index_col=0, squeeze=True),
+ ps.read_excel(path1, index_col=0),
+ pd.read_excel(path1, index_col=0),
)
self.assert_eq(ps.read_excel(tmp), pd.read_excel(path1))
path2 = "{}/file2.xlsx".format(tmp)
self.test_pdf[["i32"]].to_excel(path2)
self.assert_eq(
ps.read_excel(tmp, index_col=0).sort_index(),
pd.concat(
@@ -288,11 +289,11 @@ def test_read_excel(self):
).sort_index(),
)
self.assert_eq(
- ps.read_excel(tmp, index_col=0, squeeze=True).sort_index(),
+ ps.read_excel(tmp, index_col=0).sort_index(),
pd.concat(
[
- pd.read_excel(path1, index_col=0, squeeze=True),
- pd.read_excel(path2, index_col=0, squeeze=True),
+ pd.read_excel(path1, index_col=0),
+ pd.read_excel(path2, index_col=0),
]
).sort_index(),
)
@@ -306,21 +307,12 @@ def test_read_excel(self):
sheet_names = [["Sheet_name_1", "Sheet_name_2"], None]
pdfs1 = pd.read_excel(open(path1, "rb"), sheet_name=None, index_col=0)
- pdfs1_squeezed = pd.read_excel(
- open(path1, "rb"), sheet_name=None, index_col=0, squeeze=True
- )
for sheet_name in sheet_names:
psdfs = ps.read_excel(open(path1, "rb"), sheet_name=sheet_name, index_col=0)
self.assert_eq(psdfs["Sheet_name_1"], pdfs1["Sheet_name_1"])
self.assert_eq(psdfs["Sheet_name_2"], pdfs1["Sheet_name_2"])
- psdfs = ps.read_excel(
- open(path1, "rb"), sheet_name=sheet_name, index_col=0, squeeze=True
- )
- self.assert_eq(psdfs["Sheet_name_1"], pdfs1_squeezed["Sheet_name_1"])
- self.assert_eq(psdfs["Sheet_name_2"], pdfs1_squeezed["Sheet_name_2"])
-
self.assert_eq(
ps.read_excel(tmp, index_col=0, sheet_name="Sheet_name_2"),
pdfs1["Sheet_name_2"],
@@ -331,30 +323,17 @@ def test_read_excel(self):
self.assert_eq(psdfs["Sheet_name_1"], pdfs1["Sheet_name_1"])
self.assert_eq(psdfs["Sheet_name_2"], pdfs1["Sheet_name_2"])
- psdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0, squeeze=True)
- self.assert_eq(psdfs["Sheet_name_1"], pdfs1_squeezed["Sheet_name_1"])
- self.assert_eq(psdfs["Sheet_name_2"], pdfs1_squeezed["Sheet_name_2"])
-
path2 = "{}/file2.xlsx".format(tmp)
with pd.ExcelWriter(path2) as writer:
self.test_pdf.to_excel(writer, sheet_name="Sheet_name_1")
self.test_pdf[["i32"]].to_excel(writer, sheet_name="Sheet_name_2")
pdfs2 = pd.read_excel(path2, sheet_name=None, index_col=0)
- pdfs2_squeezed = pd.read_excel(path2, sheet_name=None, index_col=0, squeeze=True)
self.assert_eq(
ps.read_excel(tmp, sheet_name="Sheet_name_2", index_col=0).sort_index(),
pd.concat([pdfs1["Sheet_name_2"], pdfs2["Sheet_name_2"]]).sort_index(),
)
- self.assert_eq(
- ps.read_excel(
- tmp, sheet_name="Sheet_name_2", index_col=0, squeeze=True
- ).sort_index(),
- pd.concat(
- [pdfs1_squeezed["Sheet_name_2"], pdfs2_squeezed["Sheet_name_2"]]
- ).sort_index(),
- )
for sheet_name in sheet_names:
psdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0)
@@ -367,19 +346,26 @@ def test_read_excel(self):
pd.concat([pdfs1["Sheet_name_2"], pdfs2["Sheet_name_2"]]).sort_index(),
)
- psdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0, squeeze=True)
- self.assert_eq(
- psdfs["Sheet_name_1"].sort_index(),
- pd.concat(
- [pdfs1_squeezed["Sheet_name_1"], pdfs2_squeezed["Sheet_name_1"]]
- ).sort_index(),
- )
- self.assert_eq(
- psdfs["Sheet_name_2"].sort_index(),
- pd.concat(
- [pdfs1_squeezed["Sheet_name_2"], pdfs2_squeezed["Sheet_name_2"]]
- ).sort_index(),
- )
+ @unittest.skipIf(not have_openpyxl, openpyxl_requirement_message)
+ def test_read_large_excel(self):
+ n = 20000
+ pdf = pd.DataFrame(
+ {
+ "i32": np.arange(n, dtype=np.int32) % 3,
+ "i64": np.arange(n, dtype=np.int64) % 5,
+ "f": np.arange(n, dtype=np.float64),
+ "bhello": np.random.choice(["hello", "yo", "people"], size=n).astype("O"),
+ },
+ columns=["i32", "i64", "f", "bhello"],
+ index=np.random.rand(n),
+ )
+
+ with self.temp_dir() as tmp:
+ path = "{}/large_file.xlsx".format(tmp)
+ pdf.to_excel(path)
+
+ self.assert_eq(ps.read_excel(path), pd.read_excel(path))
+ self.assert_eq(ps.read_excel(path, nrows=10), pd.read_excel(path, nrows=10))
def test_read_orc(self):
with self.temp_dir() as tmp:
diff --git a/python/pyspark/sql/classic/column.py b/python/pyspark/sql/classic/column.py
index 161f8ba4bb7ab..fef65bcb5d54e 100644
--- a/python/pyspark/sql/classic/column.py
+++ b/python/pyspark/sql/classic/column.py
@@ -474,6 +474,13 @@ def substr(
return Column(jc)
def isin(self, *cols: Any) -> ParentColumn:
+ from pyspark.sql.classic.dataframe import DataFrame
+
+ if len(cols) == 1 and isinstance(cols[0], DataFrame):
+ df: DataFrame = cols[0]
+ jc = self._jc.isin(df._jdf)
+ return Column(jc)
+
if len(cols) == 1 and isinstance(cols[0], (list, set)):
cols = cast(Tuple, cols[0])
cols = cast(
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index a055e44564952..db02cc80dbeda 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -883,6 +883,9 @@ def isin(self, *cols: Any) -> "Column":
.. versionchanged:: 3.4.0
Supports Spark Connect.
+ .. versionchanged:: 4.1.0
+ Also takes a single :class:`DataFrame` to be used as IN subquery.
+
Parameters
----------
cols : Any
@@ -900,7 +903,7 @@ def isin(self, *cols: Any) -> "Column":
Example 1: Filter rows with names in the specified values
- >>> df[df.name.isin("Bob", "Mike")].show()
+ >>> df[df.name.isin("Bob", "Mike")].orderBy("age").show()
+---+----+
|age|name|
+---+----+
@@ -925,6 +928,26 @@ def isin(self, *cols: Any) -> "Column":
+---+----+
| 8|Mike|
+---+----+
+
+ Example 4: Take a :class:`DataFrame` and work as IN subquery
+
+ >>> df.where(df.age.isin(spark.range(6))).orderBy("age").show()
+ +---+-----+
+ |age| name|
+ +---+-----+
+ | 2|Alice|
+ | 5| Bob|
+ +---+-----+
+
+ Example 5: Multiple values for IN subquery
+
+ >>> from pyspark.sql.functions import lit, struct
+ >>> df.where(struct(df.age, df.name).isin(spark.range(6).select("id", lit("Bob")))).show()
+ +---+----+
+ |age|name|
+ +---+----+
+ | 5| Bob|
+ +---+----+
"""
...
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 360f391de6c1c..ca9bdd9b6f0c4 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -666,11 +666,11 @@ def __init__(
elif user_id is not None:
self._user_id = user_id
else:
- self._user_id = os.getenv("USER", None)
+ self._user_id = os.getenv("SPARK_USER", os.getenv("USER", None))
self._channel = self._builder.toChannel()
self._closed = False
- self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
+ self._internal_stub = grpc_lib.SparkConnectServiceStub(self._channel)
self._artifact_manager = ArtifactManager(
self._user_id, self._session_id, self._channel, self._builder.metadata()
)
@@ -686,7 +686,20 @@ def __init__(
self._progress_handlers: List[ProgressHandler] = []
# cleanup ml cache if possible
- atexit.register(self._cleanup_ml)
+ atexit.register(self._cleanup_ml_cache)
+
+ @property
+ def _stub(self) -> grpc_lib.SparkConnectServiceStub:
+ if self.is_closed:
+ raise SparkConnectException(
+ errorClass="NO_ACTIVE_SESSION", messageParameters=dict()
+ ) from None
+ return self._internal_stub
+
+ # For testing only.
+ @_stub.setter
+ def _stub(self, value: grpc_lib.SparkConnectServiceStub) -> None:
+ self._internal_stub = value
def register_progress_handler(self, handler: ProgressHandler) -> None:
"""
@@ -945,7 +958,7 @@ def to_pandas(
# DataFrame, as it may fail with a segmentation fault. Instead, we create an empty pandas
# DataFrame manually with the correct schema.
if table.num_rows == 0:
- pdf = pd.DataFrame(columns=schema.names)
+ pdf = pd.DataFrame(columns=schema.names, index=range(0))
else:
# Rename columns to avoid duplicated column names.
renamed_table = table.rename_columns([f"col_{i}" for i in range(table.num_columns)])
@@ -1215,7 +1228,9 @@ def token(self) -> Optional[str]:
"""
return self._builder.token
- def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
+ def _execute_plan_request_with_metadata(
+ self, operation_id: Optional[str] = None
+ ) -> pb2.ExecutePlanRequest:
req = pb2.ExecutePlanRequest(
session_id=self._session_id,
client_type=self._builder.userAgent,
@@ -1225,6 +1240,15 @@ def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
req.client_observed_server_side_session_id = self._server_session_id
if self._user_id:
req.user_context.user_id = self._user_id
+ if operation_id is not None:
+ try:
+ uuid.UUID(operation_id, version=4)
+ except ValueError as ve:
+ raise PySparkValueError(
+ errorClass="INVALID_OPERATION_UUID_ID",
+ messageParameters={"arg_name": "operation_id", "origin": str(ve)},
+ )
+ req.operation_id = operation_id
return req
def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1796,11 +1820,6 @@ def _handle_error(self, error: Exception) -> NoReturn:
self.thread_local.inside_error_handling = True
if isinstance(error, grpc.RpcError):
self._handle_rpc_error(error)
- elif isinstance(error, ValueError):
- if "Cannot invoke RPC" in str(error) and "closed" in str(error):
- raise SparkConnectException(
- errorClass="NO_ACTIVE_SESSION", messageParameters=dict()
- ) from None
raise error
finally:
self.thread_local.inside_error_handling = False
@@ -1876,9 +1895,12 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
status.message,
self._fetch_enriched_error(info),
self._display_server_stack_trace(),
+ status.code,
) from None
- raise SparkConnectGrpcException(status.message) from None
+ raise SparkConnectGrpcException(
+ message=status.message, grpc_status_code=status.code
+ ) from None
else:
raise SparkConnectGrpcException(str(rpc_error)) from None
@@ -1959,26 +1981,52 @@ def add_ml_cache(self, cache_id: str) -> None:
self.thread_local.ml_caches.add(cache_id)
def remove_ml_cache(self, cache_id: str) -> None:
- if not hasattr(self.thread_local, "ml_caches"):
- self.thread_local.ml_caches = set()
-
- if cache_id in self.thread_local.ml_caches:
- self._delete_ml_cache(cache_id)
-
- def _delete_ml_cache(self, cache_id: str) -> None:
+ deleted = self._delete_ml_cache([cache_id])
+ # TODO: Fix the code: change thread-local `ml_caches` to global `ml_caches`.
+ if hasattr(self.thread_local, "ml_caches"):
+ if cache_id in self.thread_local.ml_caches:
+ for obj_id in deleted:
+ self.thread_local.ml_caches.remove(obj_id)
+
+ def _delete_ml_cache(self, cache_ids: List[str]) -> List[str]:
# try best to delete the cache
try:
- command = pb2.Command()
- command.ml_command.delete.obj_ref.CopyFrom(pb2.ObjectRef(id=cache_id))
- self.execute_command(command)
+ if len(cache_ids) > 0:
+ command = pb2.Command()
+ command.ml_command.delete.obj_refs.extend(
+ [pb2.ObjectRef(id=cache_id) for cache_id in cache_ids]
+ )
+ (_, properties, _) = self.execute_command(command)
+
+ assert properties is not None
+
+ if properties is not None and "ml_command_result" in properties:
+ ml_command_result = properties["ml_command_result"]
+ deleted = ml_command_result.operator_info.obj_ref.id.split(",")
+ return cast(List[str], deleted)
+ return []
except Exception:
- pass
+ return []
- def _cleanup_ml(self) -> None:
- if not hasattr(self.thread_local, "ml_caches"):
- self.thread_local.ml_caches = set()
+ def _cleanup_ml_cache(self) -> None:
+ if hasattr(self.thread_local, "ml_caches"):
+ try:
+ command = pb2.Command()
+ command.ml_command.clean_cache.SetInParent()
+ self.execute_command(command)
+ self.thread_local.ml_caches.clear()
+ except Exception:
+ pass
+
+ def _get_ml_cache_info(self) -> List[str]:
+ if hasattr(self.thread_local, "ml_caches"):
+ command = pb2.Command()
+ command.ml_command.get_cache_info.SetInParent()
+ (_, properties, _) = self.execute_command(command)
+
+ assert properties is not None
- self.disable_reattachable_execute()
- # Todo add a pattern to delete all model in one command
- for model_id in self.thread_local.ml_caches:
- self._delete_ml_cache(model_id)
+ if properties is not None and "ml_command_result" in properties:
+ ml_command_result = properties["ml_command_result"]
+ return [item.string for item in ml_command_result.param.array.elements]
+ return []
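With the _stub property above, any RPC attempted after the client is closed now fails fast with a NO_ACTIVE_SESSION SparkConnectException instead of surfacing gRPC's "Cannot invoke RPC on closed channel" ValueError (whose special-casing is removed from _handle_error). A sketch of the observable behaviour, assuming spark is a Spark Connect session:

    from pyspark.errors import SparkConnectException

    spark.stop()
    try:
        spark.range(1).collect()   # triggers an RPC through the now-closed client
    except SparkConnectException as e:
        print(e)                   # NO_ACTIVE_SESSION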
diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py
index 91b7aa125920d..a827150cfb76a 100644
--- a/python/pyspark/sql/connect/client/reattach.py
+++ b/python/pyspark/sql/connect/client/reattach.py
@@ -213,7 +213,9 @@ def target() -> None:
with self._lock:
if self._release_thread_pool_instance is not None:
- self._release_thread_pool.submit(target)
+ thread_pool = self._release_thread_pool
+ if not thread_pool._shutdown:
+ thread_pool.submit(target)
def _release_all(self) -> None:
"""
@@ -238,7 +240,9 @@ def target() -> None:
with self._lock:
if self._release_thread_pool_instance is not None:
- self._release_thread_pool.submit(target)
+ thread_pool = self._release_thread_pool
+ if not thread_pool._shutdown:
+ thread_pool.submit(target)
self._result_complete = True
def _call_iter(self, iter_fun: Callable) -> Any:
@@ -261,14 +265,22 @@ def _call_iter(self, iter_fun: Callable) -> Any:
return iter_fun()
except grpc.RpcError as e:
status = rpc_status.from_call(cast(grpc.Call, e))
- if status is not None and (
- "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message
- or "INVALID_HANDLE.SESSION_NOT_FOUND" in status.message
- ):
+ unexpected_error = next(
+ (
+ error
+ for error in [
+ "INVALID_HANDLE.OPERATION_NOT_FOUND",
+ "INVALID_HANDLE.SESSION_NOT_FOUND",
+ ]
+ if status is not None and error in status.message
+ ),
+ None,
+ )
+ if unexpected_error is not None:
if self._last_returned_response_id is not None:
raise PySparkRuntimeError(
errorClass="RESPONSE_ALREADY_RECEIVED",
- messageParameters={},
+ messageParameters={"error_type": unexpected_error},
)
# Try a new ExecutePlan, and throw upstream for retry.
self._iterator = iter(
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 15d943175850c..d6ed62ba4a523 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -47,6 +47,7 @@
LiteralExpression,
CaseWhen,
SortOrder,
+ SubqueryExpression,
CastExpression,
WindowExpression,
WithField,
@@ -461,6 +462,18 @@ def outer(self) -> ParentColumn:
return Column(self._expr)
def isin(self, *cols: Any) -> ParentColumn:
+ from pyspark.sql.connect.dataframe import DataFrame
+
+ if len(cols) == 1 and isinstance(cols[0], DataFrame):
+ if isinstance(self._expr, UnresolvedFunction) and self._expr._name == "struct":
+ values = self._expr.children
+ else:
+ values = [self._expr]
+
+ return Column(
+ SubqueryExpression(cols[0]._plan, subquery_type="in", in_subquery_values=values)
+ )
+
if len(cols) == 1 and isinstance(cols[0], (list, set)):
_cols = list(cols[0])
else:
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index e5b10be41963d..872770ee22911 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -1241,9 +1241,10 @@ def __init__(
partition_spec: Optional[Sequence["Expression"]] = None,
order_spec: Optional[Sequence["SortOrder"]] = None,
with_single_partition: Optional[bool] = None,
+ in_subquery_values: Optional[Sequence["Expression"]] = None,
) -> None:
assert isinstance(subquery_type, str)
- assert subquery_type in ("scalar", "exists", "table_arg")
+ assert subquery_type in ("scalar", "exists", "table_arg", "in")
super().__init__()
self._plan = plan
@@ -1251,6 +1252,7 @@ def __init__(
self._partition_spec = partition_spec or []
self._order_spec = order_spec or []
self._with_single_partition = with_single_partition
+ self._in_subquery_values = in_subquery_values or []
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = self._create_proto_expression()
@@ -1276,17 +1278,25 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
)
if self._with_single_partition is not None:
table_arg_options.with_single_partition = self._with_single_partition
+ elif self._subquery_type == "in":
+ expr.subquery_expression.subquery_type = proto.SubqueryExpression.SUBQUERY_TYPE_IN
+ expr.subquery_expression.in_subquery_values.extend(
+ [expr.to_plan(session) for expr in self._in_subquery_values]
+ )
return expr
def __repr__(self) -> str:
repr_parts = [f"plan={self._plan}", f"type={self._subquery_type}"]
- if self._partition_spec:
- repr_parts.append(f"partition_spec={self._partition_spec}")
- if self._order_spec:
- repr_parts.append(f"order_spec={self._order_spec}")
- if self._with_single_partition is not None:
- repr_parts.append(f"with_single_partition={self._with_single_partition}")
+ if self._subquery_type == "table_arg":
+ if self._partition_spec:
+ repr_parts.append(f"partition_spec={self._partition_spec}")
+ if self._order_spec:
+ repr_parts.append(f"order_spec={self._order_spec}")
+ if self._with_single_partition is not None:
+ repr_parts.append(f"with_single_partition={self._with_single_partition}")
+ elif self._subquery_type == "in":
+ repr_parts.append(f"values={self._in_subquery_values}")
return f"SubqueryExpression({', '.join(repr_parts)})"
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index 92af6d2eaba2f..f49495ef05bd2 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -37,7 +37,7 @@
ValuesView,
cast,
)
-import random
+import random as py_random
import sys
import numpy as np
@@ -411,17 +411,20 @@ def rand(seed: Optional[int] = None) -> Column:
if seed is not None:
return _invoke_function("rand", lit(seed))
else:
- return _invoke_function("rand", lit(random.randint(0, sys.maxsize)))
+ return _invoke_function("rand", lit(py_random.randint(0, sys.maxsize)))
rand.__doc__ = pysparkfuncs.rand.__doc__
+random = rand
+
+
def randn(seed: Optional[int] = None) -> Column:
if seed is not None:
return _invoke_function("randn", lit(seed))
else:
- return _invoke_function("randn", lit(random.randint(0, sys.maxsize)))
+ return _invoke_function("randn", lit(py_random.randint(0, sys.maxsize)))
randn.__doc__ = pysparkfuncs.randn.__doc__
@@ -1009,7 +1012,7 @@ def uniform(
) -> Column:
if seed is None:
return _invoke_function_over_columns(
- "uniform", lit(min), lit(max), lit(random.randint(0, sys.maxsize))
+ "uniform", lit(min), lit(max), lit(py_random.randint(0, sys.maxsize))
)
else:
return _invoke_function_over_columns("uniform", lit(min), lit(max), lit(seed))
@@ -1198,7 +1201,7 @@ def count_min_sketch(
confidence: Union[Column, float],
seed: Optional[Union[Column, int]] = None,
) -> Column:
- _seed = lit(random.randint(0, sys.maxsize)) if seed is None else lit(seed)
+ _seed = lit(py_random.randint(0, sys.maxsize)) if seed is None else lit(seed)
return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed)
@@ -1270,7 +1273,7 @@ def mode(col: "ColumnOrName", deterministic: bool = False) -> Column:
def percentile(
col: "ColumnOrName",
- percentage: Union[Column, float, Sequence[float], Tuple[float]],
+ percentage: Union[Column, float, Sequence[float], Tuple[float, ...]],
frequency: Union[Column, int] = 1,
) -> Column:
if not isinstance(frequency, (int, Column)):
@@ -1291,7 +1294,7 @@ def percentile(
def percentile_approx(
col: "ColumnOrName",
- percentage: Union[Column, float, Sequence[float], Tuple[float]],
+ percentage: Union[Column, float, Sequence[float], Tuple[float, ...]],
accuracy: Union[Column, int] = 10000,
) -> Column:
percentage = lit(list(percentage)) if isinstance(percentage, (list, tuple)) else lit(percentage)
@@ -1303,7 +1306,7 @@ def percentile_approx(
def approx_percentile(
col: "ColumnOrName",
- percentage: Union[Column, float, Sequence[float], Tuple[float]],
+ percentage: Union[Column, float, Sequence[float], Tuple[float, ...]],
accuracy: Union[Column, int] = 10000,
) -> Column:
percentage = lit(list(percentage)) if isinstance(percentage, (list, tuple)) else lit(percentage)
@@ -2282,7 +2285,7 @@ def schema_of_xml(xml: Union[str, Column], options: Optional[Mapping[str, str]]
def shuffle(col: "ColumnOrName", seed: Optional[Union[Column, int]] = None) -> Column:
- _seed = lit(random.randint(0, sys.maxsize)) if seed is None else lit(seed)
+ _seed = lit(py_random.randint(0, sys.maxsize)) if seed is None else lit(seed)
return _invoke_function("shuffle", _to_col(col), _seed)
@@ -2706,7 +2709,7 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column:
def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column:
if seed is None:
return _invoke_function_over_columns(
- "randstr", lit(length), lit(random.randint(0, sys.maxsize))
+ "randstr", lit(length), lit(py_random.randint(0, sys.maxsize))
)
else:
return _invoke_function_over_columns("randstr", lit(length), lit(seed))
@@ -2999,6 +3002,13 @@ def character_length(str: "ColumnOrName") -> Column:
character_length.__doc__ = pysparkfuncs.character_length.__doc__
+def chr(n: "ColumnOrName") -> Column:
+ return _invoke_function_over_columns("chr", n)
+
+
+chr.__doc__ = pysparkfuncs.chr.__doc__
+
+
def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column:
return _invoke_function_over_columns("contains", left, right)
@@ -4013,6 +4023,10 @@ def session_user() -> Column:
session_user.__doc__ = pysparkfuncs.session_user.__doc__
+def uuid() -> Column:
+ return _invoke_function("uuid", lit(py_random.randint(0, sys.maxsize)))
+
+
def assert_true(col: "ColumnOrName", errMsg: Optional[Union[Column, str]] = None) -> Column:
errMsg = _enum_to_value(errMsg)
if errMsg is None:
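A brief, hedged usage sketch for the helpers touched above (`chr`, `uuid`, and the new `random` alias for `rand`), assuming an active session `spark` and that the classic `pyspark.sql.functions` module exposes the matching names:

```python
from pyspark.sql import functions as F

df = spark.range(3)
df.select(
    F.chr(F.lit(65)).alias("letter"),  # 65 -> 'A'
    F.uuid().alias("id"),              # a fresh UUID per row (seed generated client-side here)
    F.random().alias("r"),             # alias of rand(), uniform in [0, 1)
).show(truncate=False)
```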
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 79cf930fd8ee1..ef0384cf8252a 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -373,9 +373,11 @@ def transformWithStateInPandas(
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
- from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasUdfUtils
+ from pyspark.sql.streaming.stateful_processor_util import (
+ TransformWithStateInPySparkUdfUtils,
+ )
- udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
+ udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
if initialState is None:
udf_obj = UserDefinedFunction(
udf_util.transformWithStateUDF,
@@ -412,6 +414,58 @@ def transformWithStateInPandas(
transformWithStateInPandas.__doc__ = PySparkGroupedData.transformWithStateInPandas.__doc__
+ def transformWithState(
+ self,
+ statefulProcessor: StatefulProcessor,
+ outputStructType: Union[StructType, str],
+ outputMode: str,
+ timeMode: str,
+ initialState: Optional["GroupedData"] = None,
+ eventTimeColumnName: str = "",
+ ) -> "DataFrame":
+ from pyspark.sql.connect.udf import UserDefinedFunction
+ from pyspark.sql.connect.dataframe import DataFrame
+ from pyspark.sql.streaming.stateful_processor_util import (
+ TransformWithStateInPySparkUdfUtils,
+ )
+
+ udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
+ if initialState is None:
+ udf_obj = UserDefinedFunction(
+ udf_util.transformWithStateUDF,
+ returnType=outputStructType,
+ evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
+ )
+ initial_state_plan = None
+ initial_state_grouping_cols = None
+ else:
+ self._df._check_same_session(initialState._df)
+ udf_obj = UserDefinedFunction(
+ udf_util.transformWithStateWithInitStateUDF,
+ returnType=outputStructType,
+ evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
+ )
+ initial_state_plan = initialState._df._plan
+ initial_state_grouping_cols = initialState._grouping_cols
+
+ return DataFrame(
+ plan.TransformWithStateInPySpark(
+ child=self._df._plan,
+ grouping_cols=self._grouping_cols,
+ function=udf_obj,
+ output_schema=outputStructType,
+ output_mode=outputMode,
+ time_mode=timeMode,
+ event_time_col_name=eventTimeColumnName,
+ cols=self._df.columns,
+ initial_state_plan=initial_state_plan,
+ initial_state_grouping_cols=initial_state_grouping_cols,
+ ),
+ session=self._df._session,
+ )
+
+ transformWithState.__doc__ = PySparkGroupedData.transformWithState.__doc__
+
def applyInArrow(
self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
) -> "DataFrame":
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index c4c7a6a636307..c5b6f5430d6d5 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -2546,8 +2546,8 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return self._with_relations(plan, session)
-class TransformWithStateInPandas(LogicalPlan):
- """Logical plan object for a TransformWithStateInPandas."""
+class BaseTransformWithStateInPySpark(LogicalPlan):
+ """Base implementation of logical plan object for a TransformWithStateIn(PySpark/Pandas)."""
def __init__(
self,
@@ -2600,7 +2600,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
[c.to_plan(session) for c in self._initial_state_grouping_cols]
)
- # fill in transformWithStateInPandas related fields
+ # fill in transformWithStateInPySpark/Pandas related fields
tws_info = proto.TransformWithStateInfo()
tws_info.time_mode = self._time_mode
tws_info.event_time_column_name = self._event_time_col_name
@@ -2608,12 +2608,25 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan.group_map.transform_with_state_info.CopyFrom(tws_info)
- # wrap transformWithStateInPandasUdf in a function
+ # wrap transformWithStateInPySparkUdf in a function
plan.group_map.func.CopyFrom(self._function.to_plan_udf(session))
return self._with_relations(plan, session)
+class TransformWithStateInPySpark(BaseTransformWithStateInPySpark):
+ """Logical plan object for a TransformWithStateInPySpark."""
+
+ pass
+
+
+# Retaining this to avoid breaking backward compatibility.
+class TransformWithStateInPandas(BaseTransformWithStateInPySpark):
+ """Logical plan object for a TransformWithStateInPandas."""
+
+ pass
+
+
class PythonUDTF:
"""Represents a Python user-defined table function."""
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index f5fdd162a7083..0cec23f4857df 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -40,7 +40,7 @@
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xf3\x34\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e \x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12~\n#common_inline_user_defined_function\x18\x0f \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x1f\x63ommonInlineUserDefinedFunction\x12\x42\n\rcall_function\x18\x10 \x01(\x0b\x32\x1b.spark.connect.CallFunctionH\x00R\x0c\x63\x61llFunction\x12\x64\n\x19named_argument_expression\x18\x11 \x01(\x0b\x32&.spark.connect.NamedArgumentExpressionH\x00R\x17namedArgumentExpression\x12?\n\x0cmerge_action\x18\x13 \x01(\x0b\x32\x1a.spark.connect.MergeActionH\x00R\x0bmergeAction\x12g\n\x1atyped_aggregate_expression\x18\x14 \x01(\x0b\x32\'.spark.connect.TypedAggregateExpressionH\x00R\x18typedAggregateExpression\x12T\n\x13subquery_expression\x18\x15 \x01(\x0b\x32!.spark.connect.SubqueryExpressionH\x00R\x12subqueryExpression\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 \x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 
\x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1a\xbb\x02\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStr\x12\x44\n\teval_mode\x18\x04 \x01(\x0e\x32\'.spark.connect.Expression.Cast.EvalModeR\x08\x65valMode"b\n\x08\x45valMode\x12\x19\n\x15\x45VAL_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10\x45VAL_MODE_LEGACY\x10\x01\x12\x12\n\x0e\x45VAL_MODE_ANSI\x10\x02\x12\x11\n\rEVAL_MODE_TRY\x10\x03\x42\x0e\n\x0c\x63\x61st_to_type\x1a\xc1\x0f\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x39\n\x03map\x18\x17 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x42\n\x06struct\x18\x18 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x61\n\x11specialized_array\x18\x19 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.SpecializedArrayH\x00R\x10specializedArray\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 
\x01(\x03R\x0cmicroseconds\x1a\x82\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x1a\xe3\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12\x35\n\x04keys\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x04keys\x12\x39\n\x06values\x18\x04 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\x81\x01\n\x06Struct\x12\x38\n\x0bstruct_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\nstructType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x1a\xc0\x02\n\x10SpecializedArray\x12,\n\x05\x62ools\x18\x01 \x01(\x0b\x32\x14.spark.connect.BoolsH\x00R\x05\x62ools\x12)\n\x04ints\x18\x02 \x01(\x0b\x32\x13.spark.connect.IntsH\x00R\x04ints\x12,\n\x05longs\x18\x03 \x01(\x0b\x32\x14.spark.connect.LongsH\x00R\x05longs\x12/\n\x06\x66loats\x18\x04 \x01(\x0b\x32\x15.spark.connect.FloatsH\x00R\x06\x66loats\x12\x32\n\x07\x64oubles\x18\x05 \x01(\x0b\x32\x16.spark.connect.DoublesH\x00R\x07\x64oubles\x12\x32\n\x07strings\x18\x06 \x01(\x0b\x32\x16.spark.connect.StringsH\x00R\x07stringsB\x0c\n\nvalue_typeB\x0e\n\x0cliteral_type\x1a\xba\x01\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12\x31\n\x12is_metadata_column\x18\x03 \x01(\x08H\x01R\x10isMetadataColumn\x88\x01\x01\x42\n\n\x08_plan_idB\x15\n\x13_is_metadata_column\x1a\x82\x02\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x12$\n\x0bis_internal\x18\x05 \x01(\x08H\x00R\nisInternal\x88\x01\x01\x42\x0e\n\x0c_is_internal\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a|\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x01R\x06planId\x88\x01\x01\x42\x12\n\x10_unparsed_targetB\n\n\x08_plan_id\x1aV\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 
\x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_type"A\n\x10\x45xpressionCommon\x12-\n\x06origin\x18\x01 \x01(\x0b\x32\x15.spark.connect.OriginR\x06origin"\x8d\x03\n\x1f\x43ommonInlineUserDefinedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x39\n\npython_udf\x18\x04 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\tpythonUdf\x12I\n\x10scalar_scala_udf\x18\x05 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\x0escalarScalaUdf\x12\x33\n\x08java_udf\x18\x06 \x01(\x0b\x32\x16.spark.connect.JavaUDFH\x00R\x07javaUdf\x12\x1f\n\x0bis_distinct\x18\x07 \x01(\x08R\nisDistinctB\n\n\x08\x66unction"\xcc\x01\n\tPythonUDF\x12\x38\n\x0boutput_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVer\x12/\n\x13\x61\x64\x64itional_includes\x18\x05 \x03(\tR\x12\x61\x64\x64itionalIncludes"\xd6\x01\n\x0eScalarScalaUDF\x12\x18\n\x07payload\x18\x01 \x01(\x0cR\x07payload\x12\x37\n\ninputTypes\x18\x02 \x03(\x0b\x32\x17.spark.connect.DataTypeR\ninputTypes\x12\x37\n\noutputType\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1a\n\x08nullable\x18\x04 \x01(\x08R\x08nullable\x12\x1c\n\taggregate\x18\x05 \x01(\x08R\taggregate"\x95\x01\n\x07JavaUDF\x12\x1d\n\nclass_name\x18\x01 \x01(\tR\tclassName\x12=\n\x0boutput_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\noutputType\x88\x01\x01\x12\x1c\n\taggregate\x18\x03 \x01(\x08R\taggregateB\x0e\n\x0c_output_type"c\n\x18TypedAggregateExpression\x12G\n\x10scalar_scala_udf\x18\x01 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFR\x0escalarScalaUdf"l\n\x0c\x43\x61llFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\\\n\x17NamedArgumentExpression\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value"\x80\x04\n\x0bMergeAction\x12\x46\n\x0b\x61\x63tion_type\x18\x01 \x01(\x0e\x32%.spark.connect.MergeAction.ActionTypeR\nactionType\x12<\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\tcondition\x88\x01\x01\x12G\n\x0b\x61ssignments\x18\x03 \x03(\x0b\x32%.spark.connect.MergeAction.AssignmentR\x0b\x61ssignments\x1aj\n\nAssignment\x12+\n\x03key\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value"\xa7\x01\n\nActionType\x12\x17\n\x13\x41\x43TION_TYPE_INVALID\x10\x00\x12\x16\n\x12\x41\x43TION_TYPE_DELETE\x10\x01\x12\x16\n\x12\x41\x43TION_TYPE_INSERT\x10\x02\x12\x1b\n\x17\x41\x43TION_TYPE_INSERT_STAR\x10\x03\x12\x16\n\x12\x41\x43TION_TYPE_UPDATE\x10\x04\x12\x1b\n\x17\x41\x43TION_TYPE_UPDATE_STAR\x10\x05\x42\x0c\n\n_condition"\xe5\x04\n\x12SubqueryExpression\x12\x17\n\x07plan_id\x18\x01 \x01(\x03R\x06planId\x12S\n\rsubquery_type\x18\x02 \x01(\x0e\x32..spark.connect.SubqueryExpression.SubqueryTypeR\x0csubqueryType\x12\x62\n\x11table_arg_options\x18\x03 \x01(\x0b\x32\x31.spark.connect.SubqueryExpression.TableArgOptionsH\x00R\x0ftableArgOptions\x88\x01\x01\x1a\xea\x01\n\x0fTableArgOptions\x12@\n\x0epartition_spec\x18\x01 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12\x37\n\x15with_single_partition\x18\x03 \x01(\x08H\x00R\x13withSinglePartition\x88\x01\x01\x42\x18\n\x16_with_single_partition"z\n\x0cSubqueryType\x12\x19\n\x15SUBQUERY_TYPE_UNKNOWN\x10\x00\x12\x18\n\x14SUBQUERY_TYPE_SCALAR\x10\x01\x12\x18\n\x14SUBQUERY_TYPE_EXISTS\x10\x02\x12\x1b\n\x17SUBQUERY_TYPE_TABLE_ARG\x10\x03\x42\x14\n\x12_table_arg_optionsB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xf3\x34\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e \x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12~\n#common_inline_user_defined_function\x18\x0f \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x1f\x63ommonInlineUserDefinedFunction\x12\x42\n\rcall_function\x18\x10 \x01(\x0b\x32\x1b.spark.connect.CallFunctionH\x00R\x0c\x63\x61llFunction\x12\x64\n\x19named_argument_expression\x18\x11 \x01(\x0b\x32&.spark.connect.NamedArgumentExpressionH\x00R\x17namedArgumentExpression\x12?\n\x0cmerge_action\x18\x13 \x01(\x0b\x32\x1a.spark.connect.MergeActionH\x00R\x0bmergeAction\x12g\n\x1atyped_aggregate_expression\x18\x14 \x01(\x0b\x32\'.spark.connect.TypedAggregateExpressionH\x00R\x18typedAggregateExpression\x12T\n\x13subquery_expression\x18\x15 \x01(\x0b\x32!.spark.connect.SubqueryExpressionH\x00R\x12subqueryExpression\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 \x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 
\x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1a\xbb\x02\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStr\x12\x44\n\teval_mode\x18\x04 \x01(\x0e\x32\'.spark.connect.Expression.Cast.EvalModeR\x08\x65valMode"b\n\x08\x45valMode\x12\x19\n\x15\x45VAL_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10\x45VAL_MODE_LEGACY\x10\x01\x12\x12\n\x0e\x45VAL_MODE_ANSI\x10\x02\x12\x11\n\rEVAL_MODE_TRY\x10\x03\x42\x0e\n\x0c\x63\x61st_to_type\x1a\xc1\x0f\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x39\n\x03map\x18\x17 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x42\n\x06struct\x18\x18 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x61\n\x11specialized_array\x18\x19 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.SpecializedArrayH\x00R\x10specializedArray\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 
\x01(\x03R\x0cmicroseconds\x1a\x82\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x1a\xe3\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12\x35\n\x04keys\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x04keys\x12\x39\n\x06values\x18\x04 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\x81\x01\n\x06Struct\x12\x38\n\x0bstruct_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\nstructType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x1a\xc0\x02\n\x10SpecializedArray\x12,\n\x05\x62ools\x18\x01 \x01(\x0b\x32\x14.spark.connect.BoolsH\x00R\x05\x62ools\x12)\n\x04ints\x18\x02 \x01(\x0b\x32\x13.spark.connect.IntsH\x00R\x04ints\x12,\n\x05longs\x18\x03 \x01(\x0b\x32\x14.spark.connect.LongsH\x00R\x05longs\x12/\n\x06\x66loats\x18\x04 \x01(\x0b\x32\x15.spark.connect.FloatsH\x00R\x06\x66loats\x12\x32\n\x07\x64oubles\x18\x05 \x01(\x0b\x32\x16.spark.connect.DoublesH\x00R\x07\x64oubles\x12\x32\n\x07strings\x18\x06 \x01(\x0b\x32\x16.spark.connect.StringsH\x00R\x07stringsB\x0c\n\nvalue_typeB\x0e\n\x0cliteral_type\x1a\xba\x01\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12\x31\n\x12is_metadata_column\x18\x03 \x01(\x08H\x01R\x10isMetadataColumn\x88\x01\x01\x42\n\n\x08_plan_idB\x15\n\x13_is_metadata_column\x1a\x82\x02\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x12$\n\x0bis_internal\x18\x05 \x01(\x08H\x00R\nisInternal\x88\x01\x01\x42\x0e\n\x0c_is_internal\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a|\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x01R\x06planId\x88\x01\x01\x42\x12\n\x10_unparsed_targetB\n\n\x08_plan_id\x1aV\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 
\x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_type"A\n\x10\x45xpressionCommon\x12-\n\x06origin\x18\x01 \x01(\x0b\x32\x15.spark.connect.OriginR\x06origin"\x8d\x03\n\x1f\x43ommonInlineUserDefinedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x39\n\npython_udf\x18\x04 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\tpythonUdf\x12I\n\x10scalar_scala_udf\x18\x05 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\x0escalarScalaUdf\x12\x33\n\x08java_udf\x18\x06 \x01(\x0b\x32\x16.spark.connect.JavaUDFH\x00R\x07javaUdf\x12\x1f\n\x0bis_distinct\x18\x07 \x01(\x08R\nisDistinctB\n\n\x08\x66unction"\xcc\x01\n\tPythonUDF\x12\x38\n\x0boutput_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVer\x12/\n\x13\x61\x64\x64itional_includes\x18\x05 \x03(\tR\x12\x61\x64\x64itionalIncludes"\xd6\x01\n\x0eScalarScalaUDF\x12\x18\n\x07payload\x18\x01 \x01(\x0cR\x07payload\x12\x37\n\ninputTypes\x18\x02 \x03(\x0b\x32\x17.spark.connect.DataTypeR\ninputTypes\x12\x37\n\noutputType\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1a\n\x08nullable\x18\x04 \x01(\x08R\x08nullable\x12\x1c\n\taggregate\x18\x05 \x01(\x08R\taggregate"\x95\x01\n\x07JavaUDF\x12\x1d\n\nclass_name\x18\x01 \x01(\tR\tclassName\x12=\n\x0boutput_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\noutputType\x88\x01\x01\x12\x1c\n\taggregate\x18\x03 \x01(\x08R\taggregateB\x0e\n\x0c_output_type"c\n\x18TypedAggregateExpression\x12G\n\x10scalar_scala_udf\x18\x01 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFR\x0escalarScalaUdf"l\n\x0c\x43\x61llFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\\\n\x17NamedArgumentExpression\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value"\x80\x04\n\x0bMergeAction\x12\x46\n\x0b\x61\x63tion_type\x18\x01 \x01(\x0e\x32%.spark.connect.MergeAction.ActionTypeR\nactionType\x12<\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\tcondition\x88\x01\x01\x12G\n\x0b\x61ssignments\x18\x03 \x03(\x0b\x32%.spark.connect.MergeAction.AssignmentR\x0b\x61ssignments\x1aj\n\nAssignment\x12+\n\x03key\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value"\xa7\x01\n\nActionType\x12\x17\n\x13\x41\x43TION_TYPE_INVALID\x10\x00\x12\x16\n\x12\x41\x43TION_TYPE_DELETE\x10\x01\x12\x16\n\x12\x41\x43TION_TYPE_INSERT\x10\x02\x12\x1b\n\x17\x41\x43TION_TYPE_INSERT_STAR\x10\x03\x12\x16\n\x12\x41\x43TION_TYPE_UPDATE\x10\x04\x12\x1b\n\x17\x41\x43TION_TYPE_UPDATE_STAR\x10\x05\x42\x0c\n\n_condition"\xc5\x05\n\x12SubqueryExpression\x12\x17\n\x07plan_id\x18\x01 \x01(\x03R\x06planId\x12S\n\rsubquery_type\x18\x02 \x01(\x0e\x32..spark.connect.SubqueryExpression.SubqueryTypeR\x0csubqueryType\x12\x62\n\x11table_arg_options\x18\x03 \x01(\x0b\x32\x31.spark.connect.SubqueryExpression.TableArgOptionsH\x00R\x0ftableArgOptions\x88\x01\x01\x12G\n\x12in_subquery_values\x18\x04 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x10inSubqueryValues\x1a\xea\x01\n\x0fTableArgOptions\x12@\n\x0epartition_spec\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12\x37\n\x15with_single_partition\x18\x03 \x01(\x08H\x00R\x13withSinglePartition\x88\x01\x01\x42\x18\n\x16_with_single_partition"\x90\x01\n\x0cSubqueryType\x12\x19\n\x15SUBQUERY_TYPE_UNKNOWN\x10\x00\x12\x18\n\x14SUBQUERY_TYPE_SCALAR\x10\x01\x12\x18\n\x14SUBQUERY_TYPE_EXISTS\x10\x02\x12\x1b\n\x17SUBQUERY_TYPE_TABLE_ARG\x10\x03\x12\x14\n\x10SUBQUERY_TYPE_IN\x10\x04\x42\x14\n\x12_table_arg_optionsB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
)
_globals = globals()
@@ -130,9 +130,9 @@
_globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 8586
_globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 8753
_globals["_SUBQUERYEXPRESSION"]._serialized_start = 8770
- _globals["_SUBQUERYEXPRESSION"]._serialized_end = 9383
- _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_start = 9003
- _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_end = 9237
- _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 9239
- _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 9361
+ _globals["_SUBQUERYEXPRESSION"]._serialized_end = 9479
+ _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_start = 9076
+ _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_end = 9310
+ _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 9313
+ _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 9457
# @@protoc_insertion_point(module_scope)
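The shifted `_serialized_start`/`_serialized_end` offsets only reflect the larger `SubqueryExpression` descriptor. A small sanity check against the regenerated module (assuming it is importable):

```python
from pyspark.sql.connect.proto import expressions_pb2 as pb

desc = pb.SubqueryExpression.DESCRIPTOR
print("in_subquery_values" in desc.fields_by_name)  # True
print(pb.SubqueryExpression.SubqueryType.Name(4))   # SUBQUERY_TYPE_IN
```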
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index f6aada59a2d83..25fc04c0319e6 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1940,12 +1940,14 @@ class SubqueryExpression(google.protobuf.message.Message):
SUBQUERY_TYPE_SCALAR: SubqueryExpression._SubqueryType.ValueType # 1
SUBQUERY_TYPE_EXISTS: SubqueryExpression._SubqueryType.ValueType # 2
SUBQUERY_TYPE_TABLE_ARG: SubqueryExpression._SubqueryType.ValueType # 3
+ SUBQUERY_TYPE_IN: SubqueryExpression._SubqueryType.ValueType # 4
class SubqueryType(_SubqueryType, metaclass=_SubqueryTypeEnumTypeWrapper): ...
SUBQUERY_TYPE_UNKNOWN: SubqueryExpression.SubqueryType.ValueType # 0
SUBQUERY_TYPE_SCALAR: SubqueryExpression.SubqueryType.ValueType # 1
SUBQUERY_TYPE_EXISTS: SubqueryExpression.SubqueryType.ValueType # 2
SUBQUERY_TYPE_TABLE_ARG: SubqueryExpression.SubqueryType.ValueType # 3
+ SUBQUERY_TYPE_IN: SubqueryExpression.SubqueryType.ValueType # 4
class TableArgOptions(google.protobuf.message.Message):
"""Nested message for table argument options."""
@@ -2010,6 +2012,7 @@ class SubqueryExpression(google.protobuf.message.Message):
PLAN_ID_FIELD_NUMBER: builtins.int
SUBQUERY_TYPE_FIELD_NUMBER: builtins.int
TABLE_ARG_OPTIONS_FIELD_NUMBER: builtins.int
+ IN_SUBQUERY_VALUES_FIELD_NUMBER: builtins.int
plan_id: builtins.int
"""(Required) The ID of the corresponding connect plan."""
subquery_type: global___SubqueryExpression.SubqueryType.ValueType
@@ -2017,12 +2020,18 @@ class SubqueryExpression(google.protobuf.message.Message):
@property
def table_arg_options(self) -> global___SubqueryExpression.TableArgOptions:
"""(Optional) Options specific to table arguments."""
+ @property
+ def in_subquery_values(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Expression]:
+ """(Optional) IN subquery values."""
def __init__(
self,
*,
plan_id: builtins.int = ...,
subquery_type: global___SubqueryExpression.SubqueryType.ValueType = ...,
table_arg_options: global___SubqueryExpression.TableArgOptions | None = ...,
+ in_subquery_values: collections.abc.Iterable[global___Expression] | None = ...,
) -> None: ...
def HasField(
self,
@@ -2035,6 +2044,8 @@ class SubqueryExpression(google.protobuf.message.Message):
field_name: typing_extensions.Literal[
"_table_arg_options",
b"_table_arg_options",
+ "in_subquery_values",
+ b"in_subquery_values",
"plan_id",
b"plan_id",
"subquery_type",
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py
index 666cb1efdd2b4..31fa3dd5d0ec5 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_pb2.py
@@ -40,7 +40,7 @@
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xfb\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1a;\n\x06\x44\x65lete\x12\x31\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x06objRef\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_paramsB\t\n\x07\x63ommand"\x93\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\xc3\x01\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
+ b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb2\x0b\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1a=\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 \x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_paramsB\t\n\x07\x63ommand"\x93\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\xc3\x01\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
)
_globals = globals()
@@ -54,21 +54,25 @@
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001"
_globals["_MLCOMMAND"]._serialized_start = 137
- _globals["_MLCOMMAND"]._serialized_end = 1412
- _globals["_MLCOMMAND_FIT"]._serialized_start = 480
- _globals["_MLCOMMAND_FIT"]._serialized_end = 658
- _globals["_MLCOMMAND_DELETE"]._serialized_start = 660
- _globals["_MLCOMMAND_DELETE"]._serialized_end = 719
- _globals["_MLCOMMAND_WRITE"]._serialized_start = 722
- _globals["_MLCOMMAND_WRITE"]._serialized_end = 1132
- _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1034
- _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1092
- _globals["_MLCOMMAND_READ"]._serialized_start = 1134
- _globals["_MLCOMMAND_READ"]._serialized_end = 1215
- _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1218
- _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1401
- _globals["_MLCOMMANDRESULT"]._serialized_start = 1415
- _globals["_MLCOMMANDRESULT"]._serialized_end = 1818
- _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1608
- _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1803
+ _globals["_MLCOMMAND"]._serialized_end = 1595
+ _globals["_MLCOMMAND_FIT"]._serialized_start = 631
+ _globals["_MLCOMMAND_FIT"]._serialized_end = 809
+ _globals["_MLCOMMAND_DELETE"]._serialized_start = 811
+ _globals["_MLCOMMAND_DELETE"]._serialized_end = 872
+ _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 874
+ _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 886
+ _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 888
+ _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 902
+ _globals["_MLCOMMAND_WRITE"]._serialized_start = 905
+ _globals["_MLCOMMAND_WRITE"]._serialized_end = 1315
+ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1217
+ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1275
+ _globals["_MLCOMMAND_READ"]._serialized_start = 1317
+ _globals["_MLCOMMAND_READ"]._serialized_end = 1398
+ _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1401
+ _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1584
+ _globals["_MLCOMMANDRESULT"]._serialized_start = 1598
+ _globals["_MLCOMMANDRESULT"]._serialized_end = 2001
+ _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1791
+ _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1986
# @@protoc_insertion_point(module_scope)
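As with the expressions module, the offset churn here comes from the bigger `MlCommand` descriptor. A quick check that the new command variants are present (assuming the regenerated module is importable):

```python
from pyspark.sql.connect.proto import ml_pb2

fields = ml_pb2.MlCommand.DESCRIPTOR.fields_by_name
print("clean_cache" in fields, "get_cache_info" in fields)              # True True
print("obj_refs" in ml_pb2.MlCommand.Delete.DESCRIPTOR.fields_by_name)  # True
```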
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi b/python/pyspark/sql/connect/proto/ml_pb2.pyi
index 3a1e9155d71dc..9f6f4c1516d8d 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi
@@ -111,25 +111,45 @@ class MlCommand(google.protobuf.message.Message):
) -> typing_extensions.Literal["params"] | None: ...
class Delete(google.protobuf.message.Message):
- """Command to delete the cached object which could be a model
+ """Command to delete the cached objects which could be a model
or summary evaluated by a model
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
- OBJ_REF_FIELD_NUMBER: builtins.int
+ OBJ_REFS_FIELD_NUMBER: builtins.int
@property
- def obj_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ObjectRef: ...
+ def obj_refs(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.ml_common_pb2.ObjectRef
+ ]: ...
def __init__(
self,
*,
- obj_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ...,
+ obj_refs: collections.abc.Iterable[pyspark.sql.connect.proto.ml_common_pb2.ObjectRef]
+ | None = ...,
) -> None: ...
- def HasField(
- self, field_name: typing_extensions.Literal["obj_ref", b"obj_ref"]
- ) -> builtins.bool: ...
def ClearField(
- self, field_name: typing_extensions.Literal["obj_ref", b"obj_ref"]
+ self, field_name: typing_extensions.Literal["obj_refs", b"obj_refs"]
+ ) -> None: ...
+
+ class CleanCache(google.protobuf.message.Message):
+ """Force to clean up all the ML cached objects"""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ def __init__(
+ self,
+ ) -> None: ...
+
+ class GetCacheInfo(google.protobuf.message.Message):
+ """Get the information of all the ML cached objects"""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ def __init__(
+ self,
) -> None: ...
class Write(google.protobuf.message.Message):
@@ -328,6 +348,8 @@ class MlCommand(google.protobuf.message.Message):
WRITE_FIELD_NUMBER: builtins.int
READ_FIELD_NUMBER: builtins.int
EVALUATE_FIELD_NUMBER: builtins.int
+ CLEAN_CACHE_FIELD_NUMBER: builtins.int
+ GET_CACHE_INFO_FIELD_NUMBER: builtins.int
@property
def fit(self) -> global___MlCommand.Fit: ...
@property
@@ -340,6 +362,10 @@ class MlCommand(google.protobuf.message.Message):
def read(self) -> global___MlCommand.Read: ...
@property
def evaluate(self) -> global___MlCommand.Evaluate: ...
+ @property
+ def clean_cache(self) -> global___MlCommand.CleanCache: ...
+ @property
+ def get_cache_info(self) -> global___MlCommand.GetCacheInfo: ...
def __init__(
self,
*,
@@ -349,10 +375,14 @@ class MlCommand(google.protobuf.message.Message):
write: global___MlCommand.Write | None = ...,
read: global___MlCommand.Read | None = ...,
evaluate: global___MlCommand.Evaluate | None = ...,
+ clean_cache: global___MlCommand.CleanCache | None = ...,
+ get_cache_info: global___MlCommand.GetCacheInfo | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
+ "clean_cache",
+ b"clean_cache",
"command",
b"command",
"delete",
@@ -363,6 +393,8 @@ class MlCommand(google.protobuf.message.Message):
b"fetch",
"fit",
b"fit",
+ "get_cache_info",
+ b"get_cache_info",
"read",
b"read",
"write",
@@ -372,6 +404,8 @@ class MlCommand(google.protobuf.message.Message):
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "clean_cache",
+ b"clean_cache",
"command",
b"command",
"delete",
@@ -382,6 +416,8 @@ class MlCommand(google.protobuf.message.Message):
b"fetch",
"fit",
b"fit",
+ "get_cache_info",
+ b"get_cache_info",
"read",
b"read",
"write",
@@ -391,7 +427,10 @@ class MlCommand(google.protobuf.message.Message):
def WhichOneof(
self, oneof_group: typing_extensions.Literal["command", b"command"]
) -> (
- typing_extensions.Literal["fit", "fetch", "delete", "write", "read", "evaluate"] | None
+ typing_extensions.Literal[
+ "fit", "fetch", "delete", "write", "read", "evaluate", "clean_cache", "get_cache_info"
+ ]
+ | None
): ...
global___MlCommand = MlCommand
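A hedged sketch of building the new command variants from the regenerated stubs; the `ObjectRef` field name `id` is assumed from `ml_common.proto`:

```python
from pyspark.sql.connect.proto import ml_common_pb2, ml_pb2

clean = ml_pb2.MlCommand(clean_cache=ml_pb2.MlCommand.CleanCache())
info = ml_pb2.MlCommand(get_cache_info=ml_pb2.MlCommand.GetCacheInfo())
delete = ml_pb2.MlCommand(
    delete=ml_pb2.MlCommand.Delete(
        obj_refs=[ml_common_pb2.ObjectRef(id="model-1"), ml_common_pb2.ObjectRef(id="summary-1")]
    )
)
print(clean.WhichOneof("command"), info.WhichOneof("command"))  # clean_cache get_cache_info
print(len(delete.delete.obj_refs))  # 2
```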
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 79677262eb167..efa9ce7c2c435 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -1074,11 +1074,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
overwrite_conf["spark.connect.grpc.binding.port"] = "0"
origin_remote = os.environ.get("SPARK_REMOTE", None)
+ origin_connect_mode = os.environ.get("SPARK_CONNECT_MODE", None)
try:
+ # So SparkSubmit thinks no remote is set in order to
+ # start the regular PySpark session.
if origin_remote is not None:
- # So SparkSubmit thinks no remote is set in order to
- # start the regular PySpark session.
del os.environ["SPARK_REMOTE"]
+ os.environ["SPARK_CONNECT_MODE"] = "0"
# The regular PySpark session is registered as an active session
# so would not be garbage-collected.
@@ -1096,6 +1098,8 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
finally:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
+ if origin_connect_mode is not None:
+ os.environ["SPARK_CONNECT_MODE"] = origin_connect_mode
else:
raise PySparkRuntimeError(
errorClass="SESSION_OR_CONTEXT_EXISTS",
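The session change follows a save/override/restore pattern for `SPARK_CONNECT_MODE`, mirroring the existing `SPARK_REMOTE` handling. An illustrative restatement of that pattern (the context manager is hypothetical, not part of the change):

```python
import os
from contextlib import contextmanager

@contextmanager
def _classic_submit_env():  # hypothetical helper; mirrors the diff's save/restore logic
    origin_remote = os.environ.get("SPARK_REMOTE")
    origin_connect_mode = os.environ.get("SPARK_CONNECT_MODE")
    try:
        if origin_remote is not None:
            # Make SparkSubmit start a regular (classic) PySpark session.
            del os.environ["SPARK_REMOTE"]
            os.environ["SPARK_CONNECT_MODE"] = "0"
        yield
    finally:
        if origin_remote is not None:
            os.environ["SPARK_REMOTE"] = origin_remote
        if origin_connect_mode is not None:
            os.environ["SPARK_CONNECT_MODE"] = origin_connect_mode
```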
diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index b471769ad4285..b819634adb5a6 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -91,9 +91,11 @@ def process(df_id, batch_id): # type: ignore[no-untyped-def]
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
# There could be a long time between each micro batch.
sock.settimeout(None)
write_int(os.getpid(), sock_file)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index a7a5066ca0d77..2c6ce87159948 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -105,9 +105,11 @@ def process(listener_event_str, listener_event_type): # type: ignore[no-untyped
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
# There could be a long time between each listener event.
sock.settimeout(None)
write_int(os.getpid(), sock_file)
diff --git a/python/pyspark/sql/connect/tvf.py b/python/pyspark/sql/connect/tvf.py
index 2fca610a5fe3a..cf94fdb64915f 100644
--- a/python/pyspark/sql/connect/tvf.py
+++ b/python/pyspark/sql/connect/tvf.py
@@ -14,21 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import Optional
+from typing import Optional, TYPE_CHECKING
from pyspark.errors import PySparkValueError
-from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.functions.builtin import _to_col
-from pyspark.sql.connect.plan import UnresolvedTableValuedFunction
-from pyspark.sql.connect.session import SparkSession
from pyspark.sql.tvf import TableValuedFunction as PySparkTableValuedFunction
+if TYPE_CHECKING:
+ from pyspark.sql.connect.column import Column
+ from pyspark.sql.connect.dataframe import DataFrame
+ from pyspark.sql.connect.session import SparkSession
+
class TableValuedFunction:
__doc__ = PySparkTableValuedFunction.__doc__
- def __init__(self, sparkSession: SparkSession):
+ def __init__(self, sparkSession: "SparkSession"):
self._sparkSession = sparkSession
def range(
@@ -37,34 +37,34 @@ def range(
end: Optional[int] = None,
step: int = 1,
numPartitions: Optional[int] = None,
- ) -> DataFrame:
+ ) -> "DataFrame":
return self._sparkSession.range( # type: ignore[return-value]
start, end, step, numPartitions
)
range.__doc__ = PySparkTableValuedFunction.range.__doc__
- def explode(self, collection: Column) -> DataFrame:
+ def explode(self, collection: "Column") -> "DataFrame":
return self._fn("explode", collection)
explode.__doc__ = PySparkTableValuedFunction.explode.__doc__
- def explode_outer(self, collection: Column) -> DataFrame:
+ def explode_outer(self, collection: "Column") -> "DataFrame":
return self._fn("explode_outer", collection)
explode_outer.__doc__ = PySparkTableValuedFunction.explode_outer.__doc__
- def inline(self, input: Column) -> DataFrame:
+ def inline(self, input: "Column") -> "DataFrame":
return self._fn("inline", input)
inline.__doc__ = PySparkTableValuedFunction.inline.__doc__
- def inline_outer(self, input: Column) -> DataFrame:
+ def inline_outer(self, input: "Column") -> "DataFrame":
return self._fn("inline_outer", input)
inline_outer.__doc__ = PySparkTableValuedFunction.inline_outer.__doc__
- def json_tuple(self, input: Column, *fields: Column) -> DataFrame:
+ def json_tuple(self, input: "Column", *fields: "Column") -> "DataFrame":
if len(fields) == 0:
raise PySparkValueError(
errorClass="CANNOT_BE_EMPTY",
@@ -74,42 +74,46 @@ def json_tuple(self, input: Column, *fields: Column) -> DataFrame:
json_tuple.__doc__ = PySparkTableValuedFunction.json_tuple.__doc__
- def posexplode(self, collection: Column) -> DataFrame:
+ def posexplode(self, collection: "Column") -> "DataFrame":
return self._fn("posexplode", collection)
posexplode.__doc__ = PySparkTableValuedFunction.posexplode.__doc__
- def posexplode_outer(self, collection: Column) -> DataFrame:
+ def posexplode_outer(self, collection: "Column") -> "DataFrame":
return self._fn("posexplode_outer", collection)
posexplode_outer.__doc__ = PySparkTableValuedFunction.posexplode_outer.__doc__
- def stack(self, n: Column, *fields: Column) -> DataFrame:
+ def stack(self, n: "Column", *fields: "Column") -> "DataFrame":
return self._fn("stack", n, *fields)
stack.__doc__ = PySparkTableValuedFunction.stack.__doc__
- def collations(self) -> DataFrame:
+ def collations(self) -> "DataFrame":
return self._fn("collations")
collations.__doc__ = PySparkTableValuedFunction.collations.__doc__
- def sql_keywords(self) -> DataFrame:
+ def sql_keywords(self) -> "DataFrame":
return self._fn("sql_keywords")
sql_keywords.__doc__ = PySparkTableValuedFunction.sql_keywords.__doc__
- def variant_explode(self, input: Column) -> DataFrame:
+ def variant_explode(self, input: "Column") -> "DataFrame":
return self._fn("variant_explode", input)
variant_explode.__doc__ = PySparkTableValuedFunction.variant_explode.__doc__
- def variant_explode_outer(self, input: Column) -> DataFrame:
+ def variant_explode_outer(self, input: "Column") -> "DataFrame":
return self._fn("variant_explode_outer", input)
variant_explode_outer.__doc__ = PySparkTableValuedFunction.variant_explode_outer.__doc__
- def _fn(self, name: str, *args: Column) -> DataFrame:
+ def _fn(self, name: str, *args: "Column") -> "DataFrame":
+ from pyspark.sql.connect.dataframe import DataFrame
+ from pyspark.sql.connect.plan import UnresolvedTableValuedFunction
+ from pyspark.sql.connect.functions.builtin import _to_col
+
return DataFrame(
UnresolvedTableValuedFunction(name, [_to_col(arg) for arg in args]), self._sparkSession
)
@@ -117,8 +121,19 @@ def _fn(self, name: str, *args: Column) -> DataFrame:
def _test() -> None:
import os
- import doctest
import sys
+
+ if os.environ.get("PYTHON_GIL", "?") == "0":
+ print("Not supported in no-GIL mode", file=sys.stderr)
+ sys.exit(0)
+
+ from pyspark.testing import should_test_connect
+
+ if not should_test_connect:
+ print("Skipping pyspark.sql.connect.tvf doctests", file=sys.stderr)
+ sys.exit(0)
+
+ import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.tvf
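The tvf.py change moves the connect imports behind `TYPE_CHECKING` and into the method body, so importing the module no longer pulls in the heavier connect machinery (and risks circular imports) at load time, while the annotations survive as strings. A condensed sketch of the same pattern (the class name is illustrative):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; nothing is imported at runtime here.
    from pyspark.sql.connect.dataframe import DataFrame
    from pyspark.sql.connect.session import SparkSession

class LazyTvfSketch:
    def __init__(self, sparkSession: "SparkSession"):
        self._sparkSession = sparkSession

    def collations(self) -> "DataFrame":
        # Deferred imports: resolved on first call, after all modules are loaded.
        from pyspark.sql.connect.dataframe import DataFrame
        from pyspark.sql.connect.plan import UnresolvedTableValuedFunction
        return DataFrame(UnresolvedTableValuedFunction("collations", []), self._sparkSession)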
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 6045e441222de..cd87a3ef74eaf 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -21,10 +21,9 @@
check_dependencies(__name__)
+import warnings
import sys
import functools
-import warnings
-from inspect import getfullargspec
from typing import cast, Callable, Any, List, TYPE_CHECKING, Optional, Union
from pyspark.util import PythonEvalType
@@ -41,7 +40,7 @@
UDFRegistration as PySparkUDFRegistration,
UserDefinedFunction as PySparkUserDefinedFunction,
)
-from pyspark.sql.utils import has_arrow
+from pyspark.sql.pandas.utils import require_minimum_pyarrow_version, require_minimum_pandas_version
from pyspark.errors import PySparkTypeError, PySparkRuntimeError
if TYPE_CHECKING:
@@ -80,26 +79,17 @@ def _create_py_udf(
eval_type: int = PythonEvalType.SQL_BATCHED_UDF
- if is_arrow_enabled and not has_arrow:
- is_arrow_enabled = False
- warnings.warn(
- "Arrow optimization failed to enable because PyArrow is not installed. "
- "Falling back to a non-Arrow-optimized UDF.",
- RuntimeWarning,
- )
-
if is_arrow_enabled:
+ eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
try:
- is_func_with_args = len(getfullargspec(f).args) > 0
- except TypeError:
- is_func_with_args = False
- if is_func_with_args:
- eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
- else:
+ require_minimum_pandas_version()
+ require_minimum_pyarrow_version()
+ except ImportError:
+ is_arrow_enabled = False
warnings.warn(
- "Arrow optimization for Python UDFs cannot be enabled for functions"
- " without arguments.",
- UserWarning,
+ "Arrow optimization failed to enable because PyArrow or Pandas is not installed. "
+ "Falling back to a non-Arrow-optimized UDF.",
+ RuntimeWarning,
)
return _create_udf(f, returnType, eval_type)
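With the udf.py change, Arrow optimization is enabled up front and only disabled when the pandas/PyArrow requirement checks raise ImportError; the old argument-count restriction is gone. A condensed sketch of the resulting decision (the helper name is illustrative):

import warnings
from pyspark.util import PythonEvalType
from pyspark.sql.pandas.utils import (
    require_minimum_pandas_version,
    require_minimum_pyarrow_version,
)

def pick_eval_type(is_arrow_enabled: bool) -> int:
    eval_type = PythonEvalType.SQL_BATCHED_UDF
    if is_arrow_enabled:
        eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
        try:
            # Both raise ImportError when the dependency is missing or too old.
            require_minimum_pandas_version()
            require_minimum_pyarrow_version()
        except ImportError:
            eval_type = PythonEvalType.SQL_BATCHED_UDF
            warnings.warn(
                "Arrow optimization disabled: PyArrow or pandas is not installed. "
                "Falling back to a non-Arrow-optimized UDF.",
                RuntimeWarning,
            )
    return eval_type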
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cd06b3fa3eeb6..c00c3f484232b 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2127,15 +2127,16 @@ def sampleBy(
Examples
--------
>>> from pyspark.sql.functions import col
- >>> dataset = spark.range(0, 100).select((col("id") % 3).alias("key"))
+ >>> dataset = spark.range(0, 100, 1, 5).select((col("id") % 3).alias("key"))
>>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
>>> sampled.groupBy("key").count().orderBy("key").show()
+---+-----+
|key|count|
+---+-----+
- | 0| 3|
- | 1| 6|
+ | 0| 4|
+ | 1| 9|
+---+-----+
+
>>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count()
33
"""
@@ -5935,7 +5936,7 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame":
@dispatch_df_method
def toDF(self, *cols: str) -> "DataFrame":
- """Returns a new :class:`DataFrame` that with new specified column names
+ """Returns a new :class:`DataFrame` with new specified column names
.. versionadded:: 1.6.0
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index 651e84e84390e..ff7178b3b1af4 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -16,7 +16,20 @@
#
from abc import ABC, abstractmethod
from collections import UserDict
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union, TYPE_CHECKING
+from dataclasses import dataclass
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ TYPE_CHECKING,
+)
from pyspark.sql import Row
from pyspark.sql.types import StructType
@@ -38,6 +51,8 @@
"InputPartition",
"SimpleDataSourceStreamReader",
"WriterCommitMessage",
+ "Filter",
+ "EqualTo",
]
@@ -234,6 +249,215 @@ def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
)
+ColumnPath = Tuple[str, ...]
+"""
+A tuple of strings representing a column reference.
+
+For example, `("a", "b", "c")` represents the column `a.b.c`.
+
+.. versionadded: 4.1.0
+"""
+
+
+@dataclass(frozen=True)
+class Filter(ABC):
+ """
+ The base class for filters used for filter pushdown.
+
+ .. versionadded: 4.1.0
+
+ Notes
+ -----
+ Column references are represented as a tuple of strings. For example:
+
+ +----------------+----------------------+
+ | Column | Representation |
+ +----------------+----------------------+
+ | `col1` | `("col1",)` |
+ | `a.b.c` | `("a", "b", "c")` |
+ +----------------+----------------------+
+
+ Literal values are represented as Python objects of types such as
+ `int`, `float`, `str`, `bool`, `datetime`, etc.
+ See the Spark SQL data types documentation
+ for more information about how values are represented in Python.
+
+ Examples
+ --------
+ Supported filters
+
+ +---------------------+--------------------------------------------+
+ | SQL filter | Representation |
+ +---------------------+--------------------------------------------+
+ | `a.b.c = 1` | `EqualTo(("a", "b", "c"), 1)` |
+ | `a = 1` | `EqualTo(("a",), 1)` |
+ | `a = 'hi'` | `EqualTo(("a",), "hi")` |
+ | `a = array(1, 2)` | `EqualTo(("a",), [1, 2])` |
+ | `a` | `EqualTo(("a",), True)` |
+ | `not a` | `Not(EqualTo(("a",), True))` |
+ | `a <> 1` | `Not(EqualTo(("a",), 1))` |
+ | `a > 1` | `GreaterThan(("a",), 1)` |
+ | `a >= 1` | `GreaterThanOrEqual(("a",), 1)` |
+ | `a < 1` | `LessThan(("a",), 1)` |
+ | `a <= 1` | `LessThanOrEqual(("a",), 1)` |
+ | `a in (1, 2, 3)` | `In(("a",), (1, 2, 3))` |
+ | `a is null` | `IsNull(("a",))` |
+ | `a is not null` | `IsNotNull(("a",))` |
+ | `a like 'abc%'` | `StringStartsWith(("a",), "abc")` |
+ | `a like '%abc'` | `StringEndsWith(("a",), "abc")` |
+ | `a like '%abc%'` | `StringContains(("a",), "abc")` |
+ +---------------------+--------------------------------------------+
+
+ Unsupported filters
+ - `a = b`
+ - `f(a, b) = 1`
+ - `a % 2 = 1`
+ - `a[0] = 1`
+ - `a < 0 or a > 1`
+ - `a like 'c%c%'`
+ - `a ilike 'hi'`
+ - `a = 'hi' collate zh`
+ """
+
+
+@dataclass(frozen=True)
+class EqualTo(Filter):
+ """
+ A filter that evaluates to `True` iff the column evaluates to a value
+ equal to `value`.
+ """
+
+ attribute: ColumnPath
+ value: Any
+
+
+@dataclass(frozen=True)
+class EqualNullSafe(Filter):
+ """
+ Performs equality comparison, similar to EqualTo. However, this differs from EqualTo
+ in that it returns `true` (rather than NULL) if both inputs are NULL, and `false`
+ (rather than NULL) if one of the inputs is NULL and the other is not.
+ """
+
+ attribute: ColumnPath
+ value: Any
+
+
+@dataclass(frozen=True)
+class GreaterThan(Filter):
+ """
+ A filter that evaluates to `True` iff the attribute evaluates to a value
+ greater than `value`.
+ """
+
+ attribute: ColumnPath
+ value: Any
+
+
+@dataclass(frozen=True)
+class GreaterThanOrEqual(Filter):
+ """
+ A filter that evaluates to `True` iff the attribute evaluates to a value
+ greater than or equal to `value`.
+ """
+
+ attribute: ColumnPath
+ value: Any
+
+
+@dataclass(frozen=True)
+class LessThan(Filter):
+ """
+ A filter that evaluates to `True` iff the attribute evaluates to a value
+ less than `value`.
+ """
+
+ attribute: ColumnPath
+ value: Any
+
+
+@dataclass(frozen=True)
+class LessThanOrEqual(Filter):
+ """
+ A filter that evaluates to `True` iff the attribute evaluates to a value
+ less than or equal to `value`.
+ """
+
+ attribute: ColumnPath
+ value: Any
+
+
+@dataclass(frozen=True)
+class In(Filter):
+ """
+ A filter that evaluates to `True` iff the attribute evaluates to one of the values
+ in the array.
+ """
+
+ attribute: ColumnPath
+ value: Tuple[Any, ...]
+
+
+@dataclass(frozen=True)
+class IsNull(Filter):
+ """
+ A filter that evaluates to `True` iff the attribute evaluates to null.
+ """
+
+ attribute: ColumnPath
+
+
+@dataclass(frozen=True)
+class IsNotNull(Filter):
+ """
+ A filter that evaluates to `True` iff the attribute evaluates to a non-null value.
+ """
+
+ attribute: ColumnPath
+
+
+@dataclass(frozen=True)
+class Not(Filter):
+ """
+ A filter that evaluates to `True` iff `child` evaluates to `False`.
+ """
+
+ child: Filter
+
+
+@dataclass(frozen=True)
+class StringStartsWith(Filter):
+ """
+ A filter that evaluates to `True` iff the attribute evaluates to
+ a string that starts with `value`.
+ """
+
+ attribute: ColumnPath
+ value: str
+
+
+@dataclass(frozen=True)
+class StringEndsWith(Filter):
+ """
+ A filter that evaluates to `True` iff the attribute evaluates to
+ a string that ends with `value`.
+ """
+
+ attribute: ColumnPath
+ value: str
+
+
+@dataclass(frozen=True)
+class StringContains(Filter):
+ """
+ A filter that evaluates to `True` iff the attribute evaluates to
+ a string that contains the string `value`.
+ """
+
+ attribute: ColumnPath
+ value: str
+
+
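Because the filter classes above are frozen dataclasses, they are immutable, compare by value, and can be constructed directly from the column-path/value representation in the table. A tiny illustrative sketch (only `Filter` and `EqualTo` are added to `__all__`; importing from `pyspark.sql.datasource` is assumed to work once this patch is applied):

from pyspark.sql.datasource import EqualTo

# `a.b.c = 1` from the table above: the column path is a tuple, the literal a plain value.
f = EqualTo(("a", "b", "c"), 1)

assert f.attribute == ("a", "b", "c") and f.value == 1
# Frozen dataclasses give value equality, which keeps filter handling easy to test.
assert f == EqualTo(("a", "b", "c"), 1)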
class InputPartition:
"""
A base class representing an input partition returned by the `partitions()`
@@ -280,6 +504,67 @@ class DataSourceReader(ABC):
.. versionadded: 4.0.0
"""
+ def pushFilters(self, filters: List["Filter"]) -> Iterable["Filter"]:
+ """
+ Called with the list of filters that can be pushed down to the data source.
+
+ The list of filters should be interpreted as the AND of the elements.
+
+ Filter pushdown allows the data source to handle a subset of filters. This
+ can improve performance by reducing the amount of data that needs to be
+ processed by Spark.
+
+ This method is called once during query planning. By default, it returns
+ all filters, indicating that no filters can be pushed down. Subclasses can
+ override this method to implement filter pushdown.
+
+ It's recommended to implement this method only for data sources that natively
+ support filtering, such as databases and GraphQL APIs.
+
+ .. versionadded: 4.1.0
+
+ Parameters
+ ----------
+ filters : list of :class:`Filter`\\s
+
+ Returns
+ -------
+ iterable of :class:`Filter`\\s
+ Filters that still need to be evaluated by Spark after the data source
+ scan. This includes unsupported filters and partially pushed filters.
+ Every returned filter must be one of the input filters by reference.
+
+ Side effects
+ ------------
+ This method is allowed to modify `self`. The object must remain picklable.
+ Modifications to `self` are visible to the `partitions()` and `read()` methods.
+
+ Examples
+ --------
+ Example filters and the resulting arguments passed to pushFilters:
+
+ +-------------------------------+---------------------------------------------+
+ | Filters | Pushdown Arguments |
+ +-------------------------------+---------------------------------------------+
+ | `a = 1 and b = 2` | `[EqualTo(("a",), 1), EqualTo(("b",), 2)]` |
+ | `a = 1 or b = 2` | `[]` |
+ | `a = 1 or (b = 2 and c = 3)` | `[]` |
+ | `a = 1 and (b = 2 or c = 3)` | `[EqualTo(("a",), 1)]` |
+ +-------------------------------+---------------------------------------------+
+
+ Implement pushFilters to support EqualTo filters only:
+
+ >>> def pushFilters(self, filters):
+ ... for filter in filters:
+ ... if isinstance(filter, EqualTo):
+ ... # Save supported filter for handling in partitions() and read()
+ ... self.filters.append(filter)
+ ... else:
+ ... # Unsupported filter
+ ... yield filter
+ """
+ return filters
+
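As a concrete (hypothetical) reader built on the contract described in the docstring above: keep the `EqualTo` filters for native evaluation and hand everything else back to Spark by yielding it. This is a sketch, not an implementation from the patch:

from typing import Iterable, Iterator, List, Tuple
from pyspark.sql.datasource import DataSourceReader, EqualTo, Filter

class EqualToOnlyReader(DataSourceReader):
    """Illustrative reader: pushes EqualTo filters, returns the rest to Spark."""

    def __init__(self) -> None:
        self.pushed: List[EqualTo] = []

    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        for f in filters:
            if isinstance(f, EqualTo):
                self.pushed.append(f)  # applied natively inside read()
            else:
                yield f  # Spark re-applies this filter after the scan

    def read(self, partition) -> Iterator[Tuple]:
        # A real reader would translate self.pushed into its native query here.
        return iter([])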
def partitions(self) -> Sequence[InputPartition]:
"""
Returns an iterator of partitions for this data source.
diff --git a/python/pyspark/sql/functions/__init__.py b/python/pyspark/sql/functions/__init__.py
index fc0120bc681d8..8ab2ac377c2a8 100644
--- a/python/pyspark/sql/functions/__init__.py
+++ b/python/pyspark/sql/functions/__init__.py
@@ -98,6 +98,7 @@
"power",
"radians",
"rand",
+ # "random": Excluded because of the name conflict with builtin random module
"randn",
"rint",
"round",
@@ -125,6 +126,7 @@
"char",
"char_length",
"character_length",
+ # "chr": Excluded because of the name conflict with builtin chr function
"collate",
"collation",
"concat_ws",
@@ -153,6 +155,7 @@
"overlay",
"position",
"printf",
+ "quote",
"randstr",
"regexp_count",
"regexp_extract",
@@ -490,6 +493,7 @@
"try_reflect",
"typeof",
"user",
+ # "uuid": Excluded because of the name conflict with builtin uuid module
"version",
# UDF, UDTF and UDT
"AnalyzeArgument",
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 34cf38bafdc68..40c61caffeac7 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -7433,7 +7433,7 @@ def nanvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
@_try_remote_functions
def percentile(
col: "ColumnOrName",
- percentage: Union[Column, float, Sequence[float], Tuple[float]],
+ percentage: Union[Column, float, Sequence[float], Tuple[float, ...]],
frequency: Union[Column, int] = 1,
) -> Column:
"""Returns the exact percentile(s) of numeric column `expr` at the given percentage(s)
@@ -7493,7 +7493,7 @@ def percentile(
@_try_remote_functions
def percentile_approx(
col: "ColumnOrName",
- percentage: Union[Column, float, Sequence[float], Tuple[float]],
+ percentage: Union[Column, float, Sequence[float], Tuple[float, ...]],
accuracy: Union[Column, int] = 10000,
) -> Column:
"""Returns the approximate `percentile` of the numeric column `col` which is the smallest value
@@ -7564,7 +7564,7 @@ def percentile_approx(
@_try_remote_functions
def approx_percentile(
col: "ColumnOrName",
- percentage: Union[Column, float, Sequence[float], Tuple[float]],
+ percentage: Union[Column, float, Sequence[float], Tuple[float, ...]],
accuracy: Union[Column, int] = 10000,
) -> Column:
"""Returns the approximate `percentile` of the numeric column `col` which is the smallest value
@@ -7687,6 +7687,9 @@ def rand(seed: Optional[int] = None) -> Column:
return _invoke_function("rand")
+random = rand
+
+
@_try_remote_functions
def randn(seed: Optional[int] = None) -> Column:
"""Generates a random column with independent and identically distributed (i.i.d.) samples
@@ -13190,6 +13193,30 @@ def session_user() -> Column:
return _invoke_function("session_user")
+@_try_remote_functions
+def uuid() -> Column:
+ """Returns an universally unique identifier (UUID) string.
+ The value is returned as a canonical UUID 36-character string.
+
+ .. versionadded:: 4.1.0
+
+ Examples
+ --------
+ >>> import pyspark.sql.functions as sf
+ >>> spark.range(5).select(sf.uuid()).show(truncate=False) # doctest: +SKIP
+ +------------------------------------+
+ |uuid() |
+ +------------------------------------+
+ |627ae05e-b319-42b5-b4e4-71c8c9754dd1|
+ |f781cce5-a2e2-464d-bc8b-426ff448e404|
+ |15e2e66e-8416-4ea2-af3c-409363408189|
+ |fb1d6178-7676-4791-baa9-f2ddcc494515|
+ |d48665e8-2657-4c6b-b7c8-8ae0cd646e41|
+ +------------------------------------+
+ """
+ return _invoke_function("uuid")
+
+
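A short usage sketch for the new `uuid` function and the `random` alias of `rand` defined above (assumes an active `spark` session; both names are intentionally excluded from `__all__` in `functions/__init__.py` to avoid shadowing the standard-library modules):

import pyspark.sql.functions as sf
from pyspark.sql.functions.builtin import random  # alias of rand added above; shadows stdlib random here

# Each row gets a fresh canonical 36-character UUID string; values change run to run.
df = spark.range(3).select(sf.uuid().alias("u"), random(seed=42).alias("r"))
df.show(truncate=False)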
@_try_remote_functions
def crc32(col: "ColumnOrName") -> Column:
"""
@@ -17158,6 +17185,41 @@ def character_length(str: "ColumnOrName") -> Column:
return _invoke_function_over_columns("character_length", str)
+@_try_remote_functions
+def chr(n: "ColumnOrName") -> Column:
+ """
+ Returns the ASCII character corresponding to the binary value of `n`.
+ If `n` is larger than 256, the result is equivalent to chr(n % 256).
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ n : :class:`~pyspark.sql.Column` or column name
+ target column to compute on.
+
+ Examples
+ --------
+ >>> import pyspark.sql.functions as sf
+ >>> spark.range(60, 70).select("*", sf.chr("id")).show()
+ +---+-------+
+ | id|chr(id)|
+ +---+-------+
+ | 60| <|
+ | 61| =|
+ | 62| >|
+ | 63| ?|
+ | 64| @|
+ | 65| A|
+ | 66| B|
+ | 67| C|
+ | 68| D|
+ | 69| E|
+ +---+-------+
+ """
+ return _invoke_function_over_columns("chr", n)
+
+
@_try_remote_functions
def try_to_binary(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> Column:
"""
@@ -26419,8 +26481,7 @@ def udf(
Defaults to :class:`StringType`.
useArrow : bool, optional
whether to use Arrow to optimize the (de)serialization. When it is None, the
- Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes effect,
- which is "true" by default.
+ Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes effect.
Examples
--------
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi
index 750e3f52709c2..193d368f6ebde 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -57,6 +57,8 @@ ArrowGroupedMapUDFType = Literal[209]
ArrowCogroupedMapUDFType = Literal[210]
PandasGroupedMapUDFTransformWithStateType = Literal[211]
PandasGroupedMapUDFTransformWithStateInitStateType = Literal[212]
+GroupedMapUDFTransformWithStateType = Literal[213]
+GroupedMapUDFTransformWithStateInitStateType = Literal[214]
class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 4a1528bda0237..be8ffacfa3d7b 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -366,7 +366,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
# Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be
# used in `returnType`.
# Note: The values inside of the table are generated by `repr`.
- # Note: Python 3.9.5, Pandas 1.4.0 and PyArrow 6.0.1 are used.
+ # Note: Python 3.11.9, Pandas 2.2.3 and PyArrow 17.0.0 are used.
# Note: Timezone is KST.
# Note: 'X' means it throws an exception during the conversion.
require_minimum_pandas_version()
@@ -415,6 +415,8 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
None,
@@ -457,6 +459,8 @@ def _validate_pandas_udf(f, evalType) -> int:
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_ARROW_BATCHED_UDF,
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index cd384000f8593..5fe711f742ce6 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
import sys
-from typing import List, Optional, Union, TYPE_CHECKING, cast
+from typing import List, Optional, Union, TYPE_CHECKING, cast, Any
import warnings
from pyspark.errors import PySparkTypeError
@@ -481,10 +481,162 @@ def transformWithStateInPandas(
-----
This function requires a full shuffle.
"""
+ return self.__transformWithState(
+ statefulProcessor,
+ outputStructType,
+ outputMode,
+ timeMode,
+ True,
+ initialState,
+ eventTimeColumnName,
+ )
+
+ def transformWithState(
+ self,
+ statefulProcessor: StatefulProcessor,
+ outputStructType: Union[StructType, str],
+ outputMode: str,
+ timeMode: str,
+ initialState: Optional["GroupedData"] = None,
+ eventTimeColumnName: str = "",
+ ) -> "DataFrame":
+ """
+ Invokes methods defined in the stateful processor used in arbitrary state API v2. It
+ requires protobuf and pyarrow as dependencies to process input/state data. We allow
+ the user to act on a per-group set of input rows along with keyed state, and the user
+ can choose to output/return 0 or more rows.
+
+ For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
+ in each trigger and the user's state/state variables will be stored persistently across
+ invocations.
+
+ The `statefulProcessor` should be a Python class that implements the interface defined in
+ :class:`StatefulProcessor`.
+
+ The `outputStructType` should be a :class:`StructType` describing the schema of all
+ elements in the returned value, `Row`. The column labels of each returned
+ `Row` must match the field names in the defined schema.
+
+ The number of `Row`s in the iterator in both the input and output can be arbitrary.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ statefulProcessor : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessor`
+ Instance of StatefulProcessor whose functions will be invoked by the operator.
+ outputStructType : :class:`pyspark.sql.types.DataType` or str
+ The type of the output records. The value can be either a
+ :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+ outputMode : str
+ The output mode of the stateful processor.
+ timeMode : str
+ The time mode semantics of the stateful processor for timers and TTL.
+ initialState : :class:`pyspark.sql.GroupedData`
+ Optional. The grouped dataframe as initial states used for initialization
+ of state variables in the first batch.
+
+ Examples
+ --------
+ >>> from typing import Iterator
+ ...
+ >>> from pyspark.sql import Row
+ >>> from pyspark.sql.functions import col, split
+ >>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
+ >>> from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType
+ ...
+ >>> spark.conf.set("spark.sql.streaming.stateStore.providerClass",
+ ... "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
+ ... # Below is a simple example to find erroneous sensors from temperature sensor data. The
+ ... # processor returns a count of total readings, while keeping erroneous reading counts
+ ... # in streaming state. A violation is defined when the temperature is above 100.
+ ... # The input data is a DataFrame with the following schema:
+ ... # `id: string, temperature: long`.
+ ... # The output schema and state schema are defined as below.
+ >>> output_schema = StructType([
+ ... StructField("id", StringType(), True),
+ ... StructField("count", IntegerType(), True)
+ ... ])
+ >>> state_schema = StructType([
+ ... StructField("value", IntegerType(), True)
+ ... ])
+ >>> class SimpleStatefulProcessor(StatefulProcessor):
+ ... def init(self, handle: StatefulProcessorHandle):
+ ... self.num_violations_state = handle.getValueState("numViolations", state_schema)
+ ...
+ ... def handleInputRows(self, key, rows):
+ ... new_violations = 0
+ ... count = 0
+ ... exists = self.num_violations_state.exists()
+ ... if exists:
+ ... existing_violations_row = self.num_violations_state.get()
+ ... existing_violations = existing_violations_row[0]
+ ... else:
+ ... existing_violations = 0
+ ... for row in rows:
+ ... if row.temperature is not None:
+ ... count += 1
+ ... if row.temperature > 100:
+ ... new_violations += 1
+ ... updated_violations = new_violations + existing_violations
+ ... self.num_violations_state.update((updated_violations,))
+ ... yield Row(id=key, count=count)
+ ...
+ ... def close(self) -> None:
+ ... pass
+
+ Input DataFrame:
+ +---+-----------+
+ | id|temperature|
+ +---+-----------+
+ | 0| 123|
+ | 0| 23|
+ | 1| 33|
+ | 1| 188|
+ | 1| 88|
+ +---+-----------+
+
+ >>> df.groupBy("value").transformWithState(statefulProcessor =
+ ... SimpleStatefulProcessor(), outputStructType=output_schema, outputMode="Update",
+ ... timeMode="None") # doctest: +SKIP
+
+ Output DataFrame:
+ +---+-----+
+ | id|count|
+ +---+-----+
+ | 0| 2|
+ | 1| 3|
+ +---+-----+
+
+ Notes
+ -----
+ This function requires a full shuffle.
+ """
+ return self.__transformWithState(
+ statefulProcessor,
+ outputStructType,
+ outputMode,
+ timeMode,
+ False,
+ initialState,
+ eventTimeColumnName,
+ )
+
+ def __transformWithState(
+ self,
+ statefulProcessor: StatefulProcessor,
+ outputStructType: Union[StructType, str],
+ outputMode: str,
+ timeMode: str,
+ usePandas: bool,
+ initialState: Optional["GroupedData"],
+ eventTimeColumnName: str,
+ ) -> "DataFrame":
from pyspark.sql import GroupedData
from pyspark.sql.functions import pandas_udf
- from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasUdfUtils
+ from pyspark.sql.streaming.stateful_processor_util import (
+ TransformWithStateInPySparkUdfUtils,
+ )
assert isinstance(self, GroupedData)
if initialState is not None:
@@ -493,33 +645,55 @@ def transformWithStateInPandas(
outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
df = self._df
- udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
+ udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
+
+ # Explicitly typed as Any since it can hold any of several eval-type literals.
+ functionType: Any = None
+ if usePandas and initialState is None:
+ functionType = PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
+ elif usePandas and initialState is not None:
+ functionType = PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
+ elif not usePandas and initialState is None:
+ functionType = PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF
+ else:
+ # not usePandas and initialState is not None
+ functionType = PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF
if initialState is None:
initial_state_java_obj = None
udf = pandas_udf(
udf_util.transformWithStateUDF, # type: ignore
returnType=outputStructType,
- functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
+ functionType=functionType,
)
else:
initial_state_java_obj = initialState._jgd
udf = pandas_udf(
udf_util.transformWithStateWithInitStateUDF, # type: ignore
returnType=outputStructType,
- functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
+ functionType=functionType,
)
udf_column = udf(*[df[col] for col in df.columns])
- jdf = self._jgd.transformWithStateInPandas(
- udf_column._jc,
- self.session._jsparkSession.parseDataType(outputStructType.json()),
- outputMode,
- timeMode,
- initial_state_java_obj,
- eventTimeColumnName,
- )
+ if usePandas:
+ jdf = self._jgd.transformWithStateInPandas(
+ udf_column._jc,
+ self.session._jsparkSession.parseDataType(outputStructType.json()),
+ outputMode,
+ timeMode,
+ initial_state_java_obj,
+ eventTimeColumnName,
+ )
+ else:
+ jdf = self._jgd.transformWithStateInPySpark(
+ udf_column._jc,
+ self.session._jsparkSession.parseDataType(outputStructType.json()),
+ outputMode,
+ timeMode,
+ initial_state_java_obj,
+ eventTimeColumnName,
+ )
return DataFrame(jdf, self.session)
def applyInArrow(
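The shared `__transformWithState` helper above resolves one of four eval types from `(usePandas, initialState)`. The same dispatch can be read as a small table, sketched here with a hypothetical helper:

from pyspark.util import PythonEvalType

def pick_function_type(use_pandas: bool, has_initial_state: bool) -> int:
    # Mirrors the if/elif chain in __transformWithState above.
    table = {
        (True, False): PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
        (True, True): PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
        (False, False): PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
        (False, True): PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
    }
    return table[(use_pandas, has_initial_state)]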
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 74d9a2ce65608..4b65059d5e709 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -20,6 +20,7 @@
"""
from itertools import groupby
+from typing import TYPE_CHECKING, Optional
import pyspark
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
@@ -31,6 +32,7 @@
UTF8Deserializer,
CPickleSerializer,
)
+from pyspark.sql import Row
from pyspark.sql.pandas.types import (
from_arrow_type,
is_variant,
@@ -48,6 +50,10 @@
IntegerType,
)
+if TYPE_CHECKING:
+ import pandas as pd
+ import pyarrow as pa
+
class SpecialLengths:
END_OF_DATA_SECTION = -1
@@ -237,7 +243,9 @@ def __init__(self, timezone, safecheck):
self._timezone = timezone
self._safecheck = safecheck
- def arrow_to_pandas(self, arrow_column, struct_in_pandas="dict", ndarray_as_list=False):
+ def arrow_to_pandas(
+ self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None
+ ):
# If the given column is a date type column, creates a series of datetime.date directly
# instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
# datetime64[ns] type handling.
@@ -260,7 +268,7 @@ def arrow_to_pandas(self, arrow_column, struct_in_pandas="dict", ndarray_as_list
# TODO(SPARK-43579): cache the converter for reuse
converter = _create_converter_to_pandas(
- data_type=from_arrow_type(arrow_column.type, prefer_timestamp_ntz=True),
+ data_type=spark_type or from_arrow_type(arrow_column.type, prefer_timestamp_ntz=True),
nullable=True,
timezone=self._timezone,
struct_in_pandas=struct_in_pandas,
@@ -394,7 +402,8 @@ def load_stream(self, stream):
for batch in batches:
pandas_batches = [
- self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()
+ self.arrow_to_pandas(c, i)
+ for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
]
if len(pandas_batches) == 0:
yield [pd.Series([pyspark._NoValue] * batch.num_rows)]
@@ -419,6 +428,7 @@ def __init__(
struct_in_pandas="dict",
ndarray_as_list=False,
arrow_cast=False,
+ input_types=None,
):
super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
self._assign_cols_by_name = assign_cols_by_name
@@ -426,8 +436,9 @@ def __init__(
self._struct_in_pandas = struct_in_pandas
self._ndarray_as_list = ndarray_as_list
self._arrow_cast = arrow_cast
+ self._input_types = input_types
- def arrow_to_pandas(self, arrow_column):
+ def arrow_to_pandas(self, arrow_column, idx):
import pyarrow.types as types
# If the arrow type is struct, return a pandas dataframe where the fields of the struct
@@ -442,18 +453,37 @@ def arrow_to_pandas(self, arrow_column):
series = [
super(ArrowStreamPandasUDFSerializer, self)
- .arrow_to_pandas(column, self._struct_in_pandas, self._ndarray_as_list)
+ .arrow_to_pandas(
+ column,
+ i,
+ self._struct_in_pandas,
+ self._ndarray_as_list,
+ spark_type=(
+ self._input_types[idx][i].dataType
+ if self._input_types is not None
+ else None
+ ),
+ )
.rename(field.name)
- for column, field in zip(arrow_column.flatten(), arrow_column.type)
+ for i, (column, field) in enumerate(zip(arrow_column.flatten(), arrow_column.type))
]
s = pd.concat(series, axis=1)
else:
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(
- arrow_column, self._struct_in_pandas, self._ndarray_as_list
+ arrow_column,
+ idx,
+ self._struct_in_pandas,
+ self._ndarray_as_list,
+ spark_type=self._input_types[idx] if self._input_types is not None else None,
)
return s
- def _create_struct_array(self, df, arrow_struct_type):
+ def _create_struct_array(
+ self,
+ df: "pd.DataFrame",
+ arrow_struct_type: "pa.StructType",
+ spark_type: Optional[StructType] = None,
+ ):
"""
Create an Arrow StructArray from the given pandas.DataFrame and arrow struct type.
@@ -461,7 +491,7 @@ def _create_struct_array(self, df, arrow_struct_type):
----------
df : pandas.DataFrame
A pandas DataFrame
- arrow_struct_type : pyarrow.DataType
+ arrow_struct_type : pyarrow.StructType
pyarrow struct type
Returns
@@ -475,7 +505,14 @@ def _create_struct_array(self, df, arrow_struct_type):
# Assign result columns by schema name if user labeled with strings
if self._assign_cols_by_name and any(isinstance(name, str) for name in df.columns):
struct_arrs = [
- self._create_array(df[field.name], field.type, arrow_cast=self._arrow_cast)
+ self._create_array(
+ df[field.name],
+ field.type,
+ spark_type=(
+ spark_type[field.name].dataType if spark_type is not None else None
+ ),
+ arrow_cast=self._arrow_cast,
+ )
for field in arrow_struct_type
]
# Assign result columns by position
@@ -486,13 +523,13 @@ def _create_struct_array(self, df, arrow_struct_type):
self._create_array(
df[df.columns[i]].rename(field.name),
field.type,
+ spark_type=spark_type[i].dataType if spark_type is not None else None,
arrow_cast=self._arrow_cast,
)
for i, field in enumerate(arrow_struct_type)
]
- struct_names = [field.name for field in arrow_struct_type]
- return pa.StructArray.from_arrays(struct_arrs, struct_names)
+ return pa.StructArray.from_arrays(struct_arrs, fields=list(arrow_struct_type))
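Passing `fields=` (rather than bare names) keeps each child field's declared nullability and metadata from the Arrow struct type. A tiny standalone pyarrow sketch of the difference:

import pyarrow as pa

struct_type = pa.struct([pa.field("a", pa.int64(), nullable=False), pa.field("b", pa.string())])
arrays = [pa.array([1, 2], type=pa.int64()), pa.array(["x", None], type=pa.string())]

# With fields=..., child "a" keeps its non-nullable declaration in the resulting
# struct type; with names=... the fields would be rebuilt as plain nullable ones.
struct_arr = pa.StructArray.from_arrays(arrays, fields=list(struct_type))
print(struct_arr.type)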
def _create_batch(self, series):
"""
@@ -513,23 +550,31 @@ def _create_batch(self, series):
import pandas as pd
import pyarrow as pa
- # Make input conform to [(series1, type1), (series2, type2), ...]
- if not isinstance(series, (list, tuple)) or (
- len(series) == 2 and isinstance(series[1], pa.DataType)
+ # Make input conform to
+ # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...]
+ if (
+ not isinstance(series, (list, tuple))
+ or (len(series) == 2 and isinstance(series[1], pa.DataType))
+ or (
+ len(series) == 3
+ and isinstance(series[1], pa.DataType)
+ and isinstance(series[2], DataType)
+ )
):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+ series = ((s[0], s[1], None) if len(s) == 2 else s for s in series)
arrs = []
- for s, t in series:
+ for s, arrow_type, spark_type in series:
# Variants are represented in arrow as structs with additional metadata (checked by
# is_variant). If the data type is Variant, return a VariantVal atomic type instead of
# a dict of two binary values.
if (
self._struct_in_pandas == "dict"
- and t is not None
- and pa.types.is_struct(t)
- and not is_variant(t)
+ and arrow_type is not None
+ and pa.types.is_struct(arrow_type)
+ and not is_variant(arrow_type)
):
# A pandas UDF should return pd.DataFrame when the return type is a struct type.
# If it returns a pd.Series, it should throw an error.
@@ -538,9 +583,13 @@ def _create_batch(self, series):
"Invalid return type. Please make sure that the UDF returns a "
"pandas.DataFrame when the specified return type is StructType."
)
- arrs.append(self._create_struct_array(s, t))
+ arrs.append(self._create_struct_array(s, arrow_type, spark_type=spark_type))
else:
- arrs.append(self._create_array(s, t, arrow_cast=self._arrow_cast))
+ arrs.append(
+ self._create_array(
+ s, arrow_type, spark_type=spark_type, arrow_cast=self._arrow_cast
+ )
+ )
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
@@ -765,8 +814,14 @@ def load_stream(self, stream):
batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
yield (
- [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
- [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()],
+ [
+ self.arrow_to_pandas(c, i)
+ for i, c in enumerate(pa.Table.from_batches(batch1).itercolumns())
+ ],
+ [
+ self.arrow_to_pandas(c, i)
+ for i, c in enumerate(pa.Table.from_batches(batch2).itercolumns())
+ ],
)
elif dataframes_in_group != 0:
@@ -926,7 +981,7 @@ def gen_data_and_state(batches):
)
state_arrow = pa.Table.from_batches([state_batch]).itercolumns()
- state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0]
+ state_pandas = [self.arrow_to_pandas(c, i) for i, c in enumerate(state_arrow)][0]
for state_idx in range(0, len(state_pandas)):
state_info_col = state_pandas.iloc[state_idx]
@@ -958,7 +1013,7 @@ def gen_data_and_state(batches):
data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows)
data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns()
- data_pandas = [self.arrow_to_pandas(c) for c in data_arrow]
+ data_pandas = [self.arrow_to_pandas(c, i) for i, c in enumerate(data_arrow)]
# state info
yield (
@@ -1182,17 +1237,19 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_p
def load_stream(self, stream):
"""
Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunk, and
- convert the data into a list of pandas.Series.
+ convert the data into Rows.
Please refer the doc of inner function `generate_data_batches` for more details how
this function works in overall.
"""
import pyarrow as pa
- from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
+ from pyspark.sql.streaming.stateful_processor_util import (
+ TransformWithStateInPySparkFuncMode,
+ )
def generate_data_batches(batches):
"""
- Deserialize ArrowRecordBatches and return a generator of pandas.Series list.
+ Deserialize ArrowRecordBatches and return a generator of Rows.
The deserialization logic assumes that Arrow RecordBatches contain the data with the
ordering that data chunks for same grouping key will appear sequentially.
@@ -1202,7 +1259,8 @@ def generate_data_batches(batches):
"""
for batch in batches:
data_pandas = [
- self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()
+ self.arrow_to_pandas(c, i)
+ for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
]
key_series = [data_pandas[o] for o in self.key_offsets]
batch_key = tuple(s[0] for s in key_series)
@@ -1212,19 +1270,28 @@ def generate_data_batches(batches):
data_batches = generate_data_batches(_batches)
for k, g in groupby(data_batches, key=lambda x: x[0]):
- yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
+ yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, g)
- yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
+ yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
- yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
+ yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
def dump_stream(self, iterator, stream):
"""
Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow
RecordBatches, and write batches to stream.
"""
- result = [(b, t) for x in iterator for y, t in x for b in y]
- super().dump_stream(result, stream)
+
+ def flatten_iterator():
+ # iterator: iter[list[(iter[pandas.DataFrame], pdf_type)]]
+ for packed in iterator:
+ iter_pdf_with_type = packed[0]
+ iter_pdf = iter_pdf_with_type[0]
+ pdf_type = iter_pdf_with_type[1]
+ for pdf in iter_pdf:
+ yield (pdf, pdf_type)
+
+ super().dump_stream(flatten_iterator(), stream)
class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer):
@@ -1244,7 +1311,9 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_p
def load_stream(self, stream):
import pyarrow as pa
- from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
+ from pyspark.sql.streaming.stateful_processor_util import (
+ TransformWithStateInPySparkFuncMode,
+ )
def generate_data_batches(batches):
"""
@@ -1280,11 +1349,15 @@ def flatten_columns(cur_batch, col_name):
"""
for batch in batches:
flatten_state_table = flatten_columns(batch, "inputData")
- data_pandas = [self.arrow_to_pandas(c) for c in flatten_state_table.itercolumns()]
+ data_pandas = [
+ self.arrow_to_pandas(c, i)
+ for i, c in enumerate(flatten_state_table.itercolumns())
+ ]
flatten_init_table = flatten_columns(batch, "initState")
init_data_pandas = [
- self.arrow_to_pandas(c) for c in flatten_init_table.itercolumns()
+ self.arrow_to_pandas(c, i)
+ for i, c in enumerate(flatten_init_table.itercolumns())
]
key_series = [data_pandas[o] for o in self.key_offsets]
init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]
@@ -1301,8 +1374,212 @@ def flatten_columns(cur_batch, col_name):
data_batches = generate_data_batches(_batches)
for k, g in groupby(data_batches, key=lambda x: x[0]):
- yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
+ yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, g)
+
+ yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+
+ yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+
+
+class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
+ """
+ Serializer used by Python worker to evaluate UDF for
+ :meth:`pyspark.sql.GroupedData.transformWithState`.
+
+ Parameters
+ ----------
+ arrow_max_records_per_batch : int
+ Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
+ """
+
+ def __init__(self, arrow_max_records_per_batch):
+ super(TransformWithStateInPySparkRowSerializer, self).__init__()
+ self.arrow_max_records_per_batch = arrow_max_records_per_batch
+ self.key_offsets = None
+
+ def load_stream(self, stream):
+ """
+ Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunks,
+ and convert the data into Rows.
+
+ Please refer to the doc of the inner function `generate_data_batches` for more details
+ on how this function works overall.
+ """
+ from pyspark.sql.streaming.stateful_processor_util import (
+ TransformWithStateInPySparkFuncMode,
+ )
+ import itertools
+
+ def generate_data_batches(batches):
+ """
+ Deserialize ArrowRecordBatches and return a generator of Row.
+
+ The deserialization logic assumes that Arrow RecordBatches arrive with data chunks for
+ the same grouping key ordered sequentially.
+
+ This function must avoid materializing multiple Arrow RecordBatches in memory at the
+ same time, and data chunks from the same grouping key should appear sequentially.
+ """
+ for batch in batches:
+ DataRow = Row(*(batch.schema.names))
+
+ # All rows in a batch share the same grouping key, so read it from the first row.
+ batch_key = tuple(batch[o][0].as_py() for o in self.key_offsets)
+ for row_idx in range(batch.num_rows):
+ row = DataRow(
+ *(batch.column(i)[row_idx].as_py() for i in range(batch.num_columns))
+ )
+ yield (batch_key, row)
+
+ _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
+ data_batches = generate_data_batches(_batches)
+
+ for k, g in groupby(data_batches, key=lambda x: x[0]):
+ chained = itertools.chain(g)
+ chained_values = map(lambda x: x[1], chained)
+ yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, chained_values)
+
+ yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+
+ yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+
+ def dump_stream(self, iterator, stream):
+ """
+ Read through an iterator of (iterator of Row), serialize them to Arrow
+ RecordBatches, and write batches to stream.
+ """
+ import pyarrow as pa
+
+ def flatten_iterator():
+ # iterator: iter[list[(iter[Row], pdf_type)]]
+ for packed in iterator:
+ iter_row_with_type = packed[0]
+ iter_row = iter_row_with_type[0]
+ pdf_type = iter_row_with_type[1]
+
+ rows_as_dict = []
+ for row in iter_row:
+ row_as_dict = row.asDict(True)
+ rows_as_dict.append(row_as_dict)
+
+ pdf_schema = pa.schema(list(pdf_type))
+ record_batch = pa.RecordBatch.from_pylist(rows_as_dict, schema=pdf_schema)
+
+ yield (record_batch, pdf_type)
+
+ return ArrowStreamUDFSerializer.dump_stream(self, flatten_iterator(), stream)
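The Row-based dump path converts each group's output rows to dicts and builds one Arrow batch against the declared schema. A minimal standalone sketch of that conversion step:

import pyarrow as pa
from pyspark.sql import Row

rows = [Row(id="a", count=2), Row(id="b", count=3)]
schema = pa.schema([pa.field("id", pa.string()), pa.field("count", pa.int32())])

# asDict(True) recurses into nested Rows; from_pylist assembles the batch column-wise.
batch = pa.RecordBatch.from_pylist([row.asDict(True) for row in rows], schema=schema)
assert batch.num_rows == 2 and batch.schema == schema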
+
+
+class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySparkRowSerializer):
+ """
+ Serializer used by Python worker to evaluate UDF for
+ :meth:`pyspark.sql.GroupedData.transformWithState` when an initial state is provided.
+
+ Parameters
+ ----------
+ Same as input parameters in TransformWithStateInPySparkRowSerializer.
+ """
+
+ def __init__(self, arrow_max_records_per_batch):
+ super(TransformWithStateInPySparkRowInitStateSerializer, self).__init__(
+ arrow_max_records_per_batch
+ )
+ self.init_key_offsets = None
+
+ def load_stream(self, stream):
+ import itertools
+ import pyarrow as pa
+ from pyspark.sql.streaming.stateful_processor_util import (
+ TransformWithStateInPySparkFuncMode,
+ )
+
+ def generate_data_batches(batches):
+ """
+ Deserialize ArrowRecordBatches and return a generator of Row.
+ The deserialization logic assumes that Arrow RecordBatches arrive with data chunks for
+ the same grouping key ordered sequentially.
+ See `TransformWithStateInPySparkPythonInitialStateRunner` for the Arrow batch schema
+ sent from the JVM.
+ This function flattens the columns of the input rows and initial state rows and feeds
+ them into the data generator.
+ """
+
+ def extract_rows(cur_batch, col_name, key_offsets):
+ data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+ data_field_names = [
+ data_column.type[i].name for i in range(data_column.type.num_fields)
+ ]
+ data_field_arrays = [
+ data_column.field(i) for i in range(data_column.type.num_fields)
+ ]
+
+ DataRow = Row(*data_field_names)
+
+ table = pa.Table.from_arrays(data_field_arrays, names=data_field_names)
+
+ if table.num_rows == 0:
+ return (None, iter([]))
+ else:
+ batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets)
+
+ rows = []
+ for row_idx in range(table.num_rows):
+ row = DataRow(
+ *(table.column(i)[row_idx].as_py() for i in range(table.num_columns))
+ )
+ rows.append(row)
+
+ return (batch_key, iter(rows))
+
+ """
+ The arrow batch is written in the schema:
+ schema: StructType = new StructType()
+ .add("inputData", dataSchema)
+ .add("initState", initStateSchema)
+ We'll parse each batch into tuples of (key, inputData, initState) and pass them into the
+ Python data generator. All rows in the same batch have the same grouping key.
+ """
+ for batch in batches:
+ (input_batch_key, input_data_iter) = extract_rows(
+ batch, "inputData", self.key_offsets
+ )
+ (init_batch_key, init_state_iter) = extract_rows(
+ batch, "initState", self.init_key_offsets
+ )
+
+ if input_batch_key is None:
+ batch_key = init_batch_key
+ else:
+ batch_key = input_batch_key
+
+ for init_state_row in init_state_iter:
+ yield (batch_key, None, init_state_row)
+
+ for input_data_row in input_data_iter:
+ yield (batch_key, input_data_row, None)
+
+ _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
+ data_batches = generate_data_batches(_batches)
+
+ for k, g in groupby(data_batches, key=lambda x: x[0]):
+ # g: iterator of (batch_key, input_data_row, init_state_row) tuples
+
+ # Both consumers share the same underlying iterator, so tee it into two copies.
+ input_values_iter, init_state_iter = itertools.tee(g, 2)
+
+ chained_input_values = itertools.chain(map(lambda x: x[1], input_values_iter))
+ chained_init_state_values = itertools.chain(map(lambda x: x[2], init_state_iter))
+
+ chained_input_values_without_none = filter(
+ lambda x: x is not None, chained_input_values
+ )
+ chained_init_state_values_without_none = filter(
+ lambda x: x is not None, chained_init_state_values
+ )
+
+ ret_tuple = (chained_input_values_without_none, chained_init_state_values_without_none)
+
+ yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, ret_tuple)
- yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
+ yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
- yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
+ yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
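Because the input-row and initial-state consumers both read from the same grouped iterator of `(key, input_row, init_state_row)` tuples, the code tees it before projecting out each side. A standalone sketch of that pattern with placeholder values:

import itertools

# Shape of the grouped stream produced above (values are placeholders).
grouped = iter([("k", "in1", None), ("k", None, "init1"), ("k", "in2", None)])

left, right = itertools.tee(grouped, 2)
input_rows = (x[1] for x in left if x[1] is not None)
init_rows = (x[2] for x in right if x[2] is not None)

assert list(init_rows) == ["init1"]
assert list(input_rows) == ["in1", "in2"]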
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index fcd70d4d18399..70d336bce85b7 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -1424,6 +1424,12 @@ def _to_numpy_type(type: DataType) -> Optional["np.dtype"]:
return np.dtype("float32")
elif type == DoubleType():
return np.dtype("float64")
+ elif type == TimestampType():
+ return np.dtype("datetime64[us]")
+ elif type == TimestampNTZType():
+ return np.dtype("datetime64[us]")
+ elif type == DayTimeIntervalType():
+ return np.dtype("timedelta64[us]")
return None
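With the new branches above, timestamp and day-time interval columns are also coerced to fixed-width, microsecond-precision NumPy dtypes. A short sketch printing the mapping (assumes NumPy is installed):

import numpy as np
from pyspark.sql.types import DayTimeIntervalType, TimestampNTZType, TimestampType

# Mirrors the added branches in _to_numpy_type.
mapping = [
    (TimestampType(), np.dtype("datetime64[us]")),
    (TimestampNTZType(), np.dtype("datetime64[us]")),
    (DayTimeIntervalType(), np.dtype("timedelta64[us]")),
]
for spark_type, np_type in mapping:
    print(spark_type.simpleString(), "->", np_type)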
@@ -1432,7 +1438,18 @@ def convert_pandas_using_numpy_type(
) -> "PandasDataFrameLike":
for field in schema.fields:
if isinstance(
- field.dataType, (ByteType, ShortType, LongType, FloatType, DoubleType, IntegerType)
+ field.dataType,
+ (
+ ByteType,
+ ShortType,
+ IntegerType,
+ LongType,
+ FloatType,
+ DoubleType,
+ TimestampType,
+ TimestampNTZType,
+ DayTimeIntervalType,
+ ),
):
np_type = _to_numpy_type(field.dataType)
df[field.name] = df[field.name].astype(np_type)
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index 5247044984a6b..8fb0457a9a035 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -704,29 +704,27 @@ def compute_box(
lfence = q1 - F.lit(whis) * iqr
ufence = q3 + F.lit(whis) * iqr
- stats_scols.append(
- F.struct(
- F.mean(colname).alias("mean"),
- med.alias("med"),
- q1.alias("q1"),
- q3.alias("q3"),
- lfence.alias("lfence"),
- ufence.alias("ufence"),
- ).alias(f"_box_plot_stats_{i}")
- )
+ stats_scols.append(F.mean(colname).alias(f"mean_{i}"))
+ stats_scols.append(med.alias(f"med_{i}"))
+ stats_scols.append(q1.alias(f"q1_{i}"))
+ stats_scols.append(q3.alias(f"q3_{i}"))
+ stats_scols.append(lfence.alias(f"lfence_{i}"))
+ stats_scols.append(ufence.alias(f"ufence_{i}"))
- sdf_stats = sdf.select(*stats_scols)
+ # compute all stats with a scalar subquery
+ stats_col = "__pyspark_plotting_box_plot_stats__"
+ sdf = sdf.select("*", sdf.select(F.struct(*stats_scols)).scalar().alias(stats_col))
result_scols = []
for i, colname in enumerate(formatted_colnames):
value = F.col(colname)
- lfence = F.col(f"_box_plot_stats_{i}.lfence")
- ufence = F.col(f"_box_plot_stats_{i}.ufence")
- mean = F.col(f"_box_plot_stats_{i}.mean")
- med = F.col(f"_box_plot_stats_{i}.med")
- q1 = F.col(f"_box_plot_stats_{i}.q1")
- q3 = F.col(f"_box_plot_stats_{i}.q3")
+ lfence = F.col(f"{stats_col}.lfence_{i}")
+ ufence = F.col(f"{stats_col}.ufence_{i}")
+ mean = F.col(f"{stats_col}.mean_{i}")
+ med = F.col(f"{stats_col}.med_{i}")
+ q1 = F.col(f"{stats_col}.q1_{i}")
+ q3 = F.col(f"{stats_col}.q3_{i}")
outlier = ~value.between(lfence, ufence)
@@ -758,5 +756,4 @@ def compute_box(
).alias(f"_box_plot_results_{i}")
)
- sdf_result = sdf.join(sdf_stats.hint("broadcast")).select(*result_scols)
- return sdf_result.first()
+ return sdf.select(*result_scols).first()
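The rewrite replaces the broadcast join against a one-row stats DataFrame with a scalar subquery: all per-column statistics are flattened into a single struct, attached to every row via `.scalar()` (as used in the hunk above), and then addressed by field name. A rough sketch of the resulting query shape, assuming a DataFrame `sdf` with one numeric column `x`:

from pyspark.sql import functions as F

stats_col = "__box_plot_stats_sketch__"
stats = F.struct(
    F.mean("x").alias("mean_0"),
    F.percentile_approx("x", 0.25).alias("q1_0"),
    F.percentile_approx("x", 0.75).alias("q3_0"),
)

# A single scalar subquery computes every statistic once; each row then reads
# e.g. F.col(f"{stats_col}.q1_0") without joining back to a separate stats frame.
with_stats = sdf.select("*", sdf.select(stats).scalar().alias(stats_col))
in_iqr = F.col("x").between(F.col(f"{stats_col}.q1_0"), F.col(f"{stats_col}.q3_0"))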
diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py
index cb618d1a691b3..08b672e86e08e 100644
--- a/python/pyspark/sql/streaming/list_state_client.py
+++ b/python/pyspark/sql/streaming/list_state_client.py
@@ -14,16 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import Dict, Iterator, List, Union, Tuple
+from typing import Any, Dict, Iterator, List, Union, Tuple
from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.types import StructType, TYPE_CHECKING
+from pyspark.sql.types import StructType
from pyspark.errors import PySparkRuntimeError
import uuid
-if TYPE_CHECKING:
- from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
-
__all__ = ["ListStateClient"]
@@ -38,9 +35,9 @@ def __init__(
self.schema = self._stateful_processor_api_client._parse_string_schema(schema)
else:
self.schema = schema
- # A dictionary to store the mapping between list state name and a tuple of pandas DataFrame
+ # A dictionary to store the mapping between list state name and a tuple of data batch
# and the index of the last row that was read.
- self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+ self.data_batch_dict: Dict[str, Tuple[Any, int, bool]] = {}
def exists(self, state_name: str) -> bool:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -64,12 +61,12 @@ def exists(self, state_name: str) -> bool:
f"Error checking value state exists: " f"{response_message[1]}"
)
- def get(self, state_name: str, iterator_id: str) -> Tuple:
+ def get(self, state_name: str, iterator_id: str) -> Tuple[Tuple, bool]:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
- if iterator_id in self.pandas_df_dict:
+ if iterator_id in self.data_batch_dict:
# If the state is already in the dictionary, return the next row.
- pandas_df, index = self.pandas_df_dict[iterator_id]
+ data_batch, index, require_next_fetch = self.data_batch_dict[iterator_id]
else:
# If the state is not in the dictionary, fetch the state from the server.
get_call = stateMessage.ListStateGet(iteratorId=iterator_id)
@@ -82,36 +79,35 @@ def get(self, state_name: str, iterator_id: str) -> Tuple:
message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
- response_message = self._stateful_processor_api_client._receive_proto_message()
+ response_message = (
+ self._stateful_processor_api_client._receive_proto_message_with_list_get()
+ )
status = response_message[0]
if status == 0:
- iterator = self._stateful_processor_api_client._read_arrow_state()
- # We need to exhaust the iterator here to make sure all the arrow batches are read,
- # even though there is only one batch in the iterator. Otherwise, the stream might
- # block further reads since it thinks there might still be some arrow batches left.
- # We only need to read the first batch in the iterator because it's guaranteed that
- # there would only be one batch sent from the JVM side.
- data_batch = None
- for batch in iterator:
- if data_batch is None:
- data_batch = batch
- if data_batch is None:
- # TODO(SPARK-49233): Classify user facing errors.
- raise PySparkRuntimeError("Error getting next list state row.")
- pandas_df = data_batch.to_pandas()
+ data_batch = list(
+ map(
+ lambda x: self._stateful_processor_api_client._deserialize_from_bytes(x),
+ response_message[2],
+ )
+ )
+ require_next_fetch = response_message[3]
index = 0
else:
raise StopIteration()
+ is_last_row = False
new_index = index + 1
- if new_index < len(pandas_df):
+ if new_index < len(data_batch):
# Update the index in the dictionary.
- self.pandas_df_dict[iterator_id] = (pandas_df, new_index)
+ self.data_batch_dict[iterator_id] = (data_batch, new_index, require_next_fetch)
else:
- # If the index is at the end of the DataFrame, remove the state from the dictionary.
- self.pandas_df_dict.pop(iterator_id, None)
- pandas_row = pandas_df.iloc[index]
- return tuple(pandas_row)
+ # If the index is at the end of the data batch, remove the state from the dictionary.
+ self.data_batch_dict.pop(iterator_id, None)
+ is_last_row = True
+
+ is_last_row_from_iterator = is_last_row and not require_next_fetch
+ row = data_batch[index]
+ return (tuple(row), is_last_row_from_iterator)
def append_value(self, state_name: str, value: Tuple) -> None:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -134,7 +130,24 @@ def append_value(self, state_name: str, value: Tuple) -> None:
def append_list(self, state_name: str, values: List[Tuple]) -> None:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
- append_list_call = stateMessage.AppendList()
+ send_data_via_arrow = False
+
+ # To work around the mypy type assignment check.
+ values_as_bytes: Any = []
+ if len(values) >= 100:
+ # TODO(SPARK-51907): Let's update this to be either configurable or a more reasonable
+ # default value backed by various benchmarks.
+ # Arrow codepath
+ send_data_via_arrow = True
+ else:
+ values_as_bytes = map(
+ lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
+ values,
+ )
+
+ append_list_call = stateMessage.AppendList(
+ value=values_as_bytes, fetchWithArrow=send_data_via_arrow
+ )
list_state_call = stateMessage.ListStateCall(
stateName=state_name, appendList=append_list_call
)
@@ -143,7 +156,9 @@ def append_list(self, state_name: str, values: List[Tuple]) -> None:
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
- self._stateful_processor_api_client._send_arrow_state(self.schema, values)
+ if send_data_via_arrow:
+ self._stateful_processor_api_client._send_arrow_state(self.schema, values)
+
response_message = self._stateful_processor_api_client._receive_proto_message()
status = response_message[0]
if status != 0:
@@ -153,14 +168,32 @@ def append_list(self, state_name: str, values: List[Tuple]) -> None:
def put(self, state_name: str, values: List[Tuple]) -> None:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
- put_call = stateMessage.ListStatePut()
+ send_data_via_arrow = False
+ # To work around the mypy type assignment check.
+ values_as_bytes: Any = []
+ if len(values) >= 100:
+ # TODO(SPARK-51907): Let's update this to be either configurable or a more reasonable
+ # default value backed by various benchmarks.
+ send_data_via_arrow = True
+ else:
+ values_as_bytes = map(
+ lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
+ values,
+ )
+
+ put_call = stateMessage.ListStatePut(
+ value=values_as_bytes, fetchWithArrow=send_data_via_arrow
+ )
+
list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call)
state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
- self._stateful_processor_api_client._send_arrow_state(self.schema, values)
+ if send_data_via_arrow:
+ self._stateful_processor_api_client._send_arrow_state(self.schema, values)
+
response_message = self._stateful_processor_api_client._receive_proto_message()
status = response_message[0]
if status != 0:
@@ -190,9 +223,17 @@ def __init__(self, list_state_client: ListStateClient, state_name: str):
# Generate a unique identifier for the iterator to make sure iterators from the same
# list state do not interfere with each other.
self.iterator_id = str(uuid.uuid4())
+ self.iterator_fully_consumed = False
def __iter__(self) -> Iterator[Tuple]:
return self
def __next__(self) -> Tuple:
- return self.list_state_client.get(self.state_name, self.iterator_id)
+ if self.iterator_fully_consumed:
+ raise StopIteration()
+
+ row, is_last_row = self.list_state_client.get(self.state_name, self.iterator_id)
+ if is_last_row:
+ self.iterator_fully_consumed = True
+
+ return row
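To make the new iterator contract above easier to follow, here is a self-contained, hedged sketch (plain Python, no Spark involved; all names are made up) of the same consumption pattern: get() returns (row, is_last_row_from_iterator), and the iterator stops without another server round trip once the last batch is exhausted and no further fetch is required.

from typing import Iterator, List, Tuple

class BatchedRows:
    """Toy stand-in for ListStateClient.get(): serves rows from pre-fetched batches."""
    def __init__(self, batches: List[List[Tuple]]):
        self._batches = batches
        self._batch_idx = 0
        self._row_idx = 0

    def get(self) -> Tuple[Tuple, bool]:
        batch = self._batches[self._batch_idx]
        row = batch[self._row_idx]
        self._row_idx += 1
        is_last_row = self._row_idx >= len(batch)
        require_next_fetch = self._batch_idx + 1 < len(self._batches)
        if is_last_row:
            # Move on to the next batch (analogous to dropping the cached entry above).
            self._batch_idx += 1
            self._row_idx = 0
        return row, is_last_row and not require_next_fetch

class RowIterator:
    def __init__(self, client: BatchedRows):
        self._client = client
        self._done = False

    def __iter__(self) -> Iterator[Tuple]:
        return self

    def __next__(self) -> Tuple:
        if self._done:
            raise StopIteration()
        row, is_last = self._client.get()
        if is_last:
            self._done = True
        return row

assert list(RowIterator(BatchedRows([[(1,), (2,)], [(3,)]]))) == [(1,), (2,), (3,)]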
diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
index 20af541f307cd..094f1dd51c584 100644
--- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
+++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
@@ -40,7 +40,7 @@
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n;org/apache/spark/sql/execution/streaming/StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\x84\x05\n\x0cStateRequest\x12\x18\n\x07version\x18\x01 \x01(\x05R\x07version\x12}\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00R\x15statefulProcessorCall\x12z\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00R\x14stateVariableRequest\x12\x8c\x01\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00R\x1aimplicitGroupingKeyRequest\x12\x62\n\x0ctimerRequest\x18\x05 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.TimerRequestH\x00R\x0ctimerRequest\x12\x62\n\x0cutilsRequest\x18\x06 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.UtilsRequestH\x00R\x0cutilsRequestB\x08\n\x06method"i\n\rStateResponse\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\x0cR\x05value"x\n\x1cStateResponseWithLongTypeVal\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\x03R\x05value"z\n\x1eStateResponseWithStringTypeVal\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\tR\x05value"\xa0\x05\n\x15StatefulProcessorCall\x12h\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00R\x0esetHandleState\x12h\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\rgetValueState\x12\x66\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0cgetListState\x12\x64\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0bgetMapState\x12o\n\x0etimerStateCall\x18\x05 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.TimerStateCallCommandH\x00R\x0etimerStateCall\x12j\n\x0e\x64\x65leteIfExists\x18\x06 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0e\x64\x65leteIfExistsB\x08\n\x06method"\xd5\x02\n\x14StateVariableRequest\x12h\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00R\x0evalueStateCall\x12\x65\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00R\rlistStateCall\x12\x62\n\x0cmapStateCall\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.MapStateCallH\x00R\x0cmapStateCallB\x08\n\x06method"\x83\x02\n\x1aImplicitGroupingKeyRequest\x12h\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00R\x0esetImplicitKey\x12q\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00R\x11removeImplicitKeyB\x08\n\x06method"\x81\x02\n\x0cTimerRequest\x12q\n\x11timerValueRequest\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.TimerValueRequestH\x00R\x11timerValueRequest\x12t\n\x12\x65xpiryTimerRequest\x18\x02 
\x01(\x0b\x32\x42.org.apache.spark.sql.execution.streaming.state.ExpiryTimerRequestH\x00R\x12\x65xpiryTimerRequestB\x08\n\x06method"\xf6\x01\n\x11TimerValueRequest\x12s\n\x12getProcessingTimer\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.GetProcessingTimeH\x00R\x12getProcessingTimer\x12\x62\n\x0cgetWatermark\x18\x02 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.GetWatermarkH\x00R\x0cgetWatermarkB\x08\n\x06method"B\n\x12\x45xpiryTimerRequest\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs"\x13\n\x11GetProcessingTime"\x0e\n\x0cGetWatermark"\x8b\x01\n\x0cUtilsRequest\x12q\n\x11parseStringSchema\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.ParseStringSchemaH\x00R\x11parseStringSchemaB\x08\n\x06method"+\n\x11ParseStringSchema\x12\x16\n\x06schema\x18\x01 \x01(\tR\x06schema"\xc7\x01\n\x10StateCallCommand\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x30\n\x13mapStateValueSchema\x18\x03 \x01(\tR\x13mapStateValueSchema\x12K\n\x03ttl\x18\x04 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfigR\x03ttl"\xa7\x02\n\x15TimerStateCallCommand\x12[\n\x08register\x18\x01 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.RegisterTimerH\x00R\x08register\x12U\n\x06\x64\x65lete\x18\x02 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.DeleteTimerH\x00R\x06\x64\x65lete\x12P\n\x04list\x18\x03 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.ListTimersH\x00R\x04listB\x08\n\x06method"\x92\x03\n\x0eValueStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12G\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00R\x03get\x12n\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00R\x10valueStateUpdate\x12M\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method"\xdf\x04\n\rListStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12\x62\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00R\x0clistStateGet\x12\x62\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00R\x0clistStatePut\x12_\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00R\x0b\x61ppendValue\x12\\\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00R\nappendList\x12M\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method"\xc2\x06\n\x0cMapStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12V\n\x08getValue\x18\x03 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.GetValueH\x00R\x08getValue\x12_\n\x0b\x63ontainsKey\x18\x04 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.ContainsKeyH\x00R\x0b\x63ontainsKey\x12_\n\x0bupdateValue\x18\x05 
\x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.UpdateValueH\x00R\x0bupdateValue\x12V\n\x08iterator\x18\x06 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.IteratorH\x00R\x08iterator\x12J\n\x04keys\x18\x07 \x01(\x0b\x32\x34.org.apache.spark.sql.execution.streaming.state.KeysH\x00R\x04keys\x12P\n\x06values\x18\x08 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ValuesH\x00R\x06values\x12Y\n\tremoveKey\x18\t \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.RemoveKeyH\x00R\tremoveKey\x12M\n\x05\x63lear\x18\n \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method""\n\x0eSetImplicitKey\x12\x10\n\x03key\x18\x01 \x01(\x0cR\x03key"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"=\n\rRegisterTimer\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs";\n\x0b\x44\x65leteTimer\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs",\n\nListTimers\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"(\n\x10ValueStateUpdate\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value"\x07\n\x05\x43lear".\n\x0cListStateGet\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"\x0e\n\x0cListStatePut"#\n\x0b\x41ppendValue\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value"\x0c\n\nAppendList"$\n\x08GetValue\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"\'\n\x0b\x43ontainsKey\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"=\n\x0bUpdateValue\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey\x12\x14\n\x05value\x18\x02 \x01(\x0cR\x05value"*\n\x08Iterator\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"&\n\x04Keys\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"(\n\x06Values\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"%\n\tRemoveKey\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"c\n\x0eSetHandleState\x12Q\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleStateR\x05state"+\n\tTTLConfig\x12\x1e\n\ndurationMs\x18\x01 \x01(\x05R\ndurationMs*n\n\x0bHandleState\x12\x0c\n\x08PRE_INIT\x10\x00\x12\x0b\n\x07\x43REATED\x10\x01\x12\x0f\n\x0bINITIALIZED\x10\x02\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x03\x12\x13\n\x0fTIMER_PROCESSED\x10\x04\x12\n\n\x06\x43LOSED\x10\x05\x62\x06proto3'
+ b'\n;org/apache/spark/sql/execution/streaming/StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\x84\x05\n\x0cStateRequest\x12\x18\n\x07version\x18\x01 \x01(\x05R\x07version\x12}\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00R\x15statefulProcessorCall\x12z\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00R\x14stateVariableRequest\x12\x8c\x01\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00R\x1aimplicitGroupingKeyRequest\x12\x62\n\x0ctimerRequest\x18\x05 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.TimerRequestH\x00R\x0ctimerRequest\x12\x62\n\x0cutilsRequest\x18\x06 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.UtilsRequestH\x00R\x0cutilsRequestB\x08\n\x06method"i\n\rStateResponse\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\x0cR\x05value"x\n\x1cStateResponseWithLongTypeVal\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\x03R\x05value"z\n\x1eStateResponseWithStringTypeVal\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\tR\x05value"\xa0\x01\n\x18StateResponseWithListGet\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x03(\x0cR\x05value\x12*\n\x10requireNextFetch\x18\x04 \x01(\x08R\x10requireNextFetch"\xa0\x05\n\x15StatefulProcessorCall\x12h\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00R\x0esetHandleState\x12h\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\rgetValueState\x12\x66\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0cgetListState\x12\x64\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0bgetMapState\x12o\n\x0etimerStateCall\x18\x05 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.TimerStateCallCommandH\x00R\x0etimerStateCall\x12j\n\x0e\x64\x65leteIfExists\x18\x06 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0e\x64\x65leteIfExistsB\x08\n\x06method"\xd5\x02\n\x14StateVariableRequest\x12h\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00R\x0evalueStateCall\x12\x65\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00R\rlistStateCall\x12\x62\n\x0cmapStateCall\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.MapStateCallH\x00R\x0cmapStateCallB\x08\n\x06method"\x83\x02\n\x1aImplicitGroupingKeyRequest\x12h\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00R\x0esetImplicitKey\x12q\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00R\x11removeImplicitKeyB\x08\n\x06method"\x81\x02\n\x0cTimerRequest\x12q\n\x11timerValueRequest\x18\x01 
\x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.TimerValueRequestH\x00R\x11timerValueRequest\x12t\n\x12\x65xpiryTimerRequest\x18\x02 \x01(\x0b\x32\x42.org.apache.spark.sql.execution.streaming.state.ExpiryTimerRequestH\x00R\x12\x65xpiryTimerRequestB\x08\n\x06method"\xf6\x01\n\x11TimerValueRequest\x12s\n\x12getProcessingTimer\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.GetProcessingTimeH\x00R\x12getProcessingTimer\x12\x62\n\x0cgetWatermark\x18\x02 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.GetWatermarkH\x00R\x0cgetWatermarkB\x08\n\x06method"B\n\x12\x45xpiryTimerRequest\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs"\x13\n\x11GetProcessingTime"\x0e\n\x0cGetWatermark"\x8b\x01\n\x0cUtilsRequest\x12q\n\x11parseStringSchema\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.ParseStringSchemaH\x00R\x11parseStringSchemaB\x08\n\x06method"+\n\x11ParseStringSchema\x12\x16\n\x06schema\x18\x01 \x01(\tR\x06schema"\xc7\x01\n\x10StateCallCommand\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x30\n\x13mapStateValueSchema\x18\x03 \x01(\tR\x13mapStateValueSchema\x12K\n\x03ttl\x18\x04 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfigR\x03ttl"\xa7\x02\n\x15TimerStateCallCommand\x12[\n\x08register\x18\x01 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.RegisterTimerH\x00R\x08register\x12U\n\x06\x64\x65lete\x18\x02 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.DeleteTimerH\x00R\x06\x64\x65lete\x12P\n\x04list\x18\x03 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.ListTimersH\x00R\x04listB\x08\n\x06method"\x92\x03\n\x0eValueStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12G\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00R\x03get\x12n\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00R\x10valueStateUpdate\x12M\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method"\xdf\x04\n\rListStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12\x62\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00R\x0clistStateGet\x12\x62\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00R\x0clistStatePut\x12_\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00R\x0b\x61ppendValue\x12\\\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00R\nappendList\x12M\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method"\xc2\x06\n\x0cMapStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12V\n\x08getValue\x18\x03 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.GetValueH\x00R\x08getValue\x12_\n\x0b\x63ontainsKey\x18\x04 
\x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.ContainsKeyH\x00R\x0b\x63ontainsKey\x12_\n\x0bupdateValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.UpdateValueH\x00R\x0bupdateValue\x12V\n\x08iterator\x18\x06 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.IteratorH\x00R\x08iterator\x12J\n\x04keys\x18\x07 \x01(\x0b\x32\x34.org.apache.spark.sql.execution.streaming.state.KeysH\x00R\x04keys\x12P\n\x06values\x18\x08 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ValuesH\x00R\x06values\x12Y\n\tremoveKey\x18\t \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.RemoveKeyH\x00R\tremoveKey\x12M\n\x05\x63lear\x18\n \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method""\n\x0eSetImplicitKey\x12\x10\n\x03key\x18\x01 \x01(\x0cR\x03key"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"=\n\rRegisterTimer\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs";\n\x0b\x44\x65leteTimer\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs",\n\nListTimers\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"(\n\x10ValueStateUpdate\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value"\x07\n\x05\x43lear".\n\x0cListStateGet\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"L\n\x0cListStatePut\x12\x14\n\x05value\x18\x01 \x03(\x0cR\x05value\x12&\n\x0e\x66\x65tchWithArrow\x18\x02 \x01(\x08R\x0e\x66\x65tchWithArrow"#\n\x0b\x41ppendValue\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value"J\n\nAppendList\x12\x14\n\x05value\x18\x01 \x03(\x0cR\x05value\x12&\n\x0e\x66\x65tchWithArrow\x18\x02 \x01(\x08R\x0e\x66\x65tchWithArrow"$\n\x08GetValue\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"\'\n\x0b\x43ontainsKey\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"=\n\x0bUpdateValue\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey\x12\x14\n\x05value\x18\x02 \x01(\x0cR\x05value"*\n\x08Iterator\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"&\n\x04Keys\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"(\n\x06Values\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"%\n\tRemoveKey\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"c\n\x0eSetHandleState\x12Q\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleStateR\x05state"+\n\tTTLConfig\x12\x1e\n\ndurationMs\x18\x01 \x01(\x05R\ndurationMs*n\n\x0bHandleState\x12\x0c\n\x08PRE_INIT\x10\x00\x12\x0b\n\x07\x43REATED\x10\x01\x12\x0f\n\x0bINITIALIZED\x10\x02\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x03\x12\x13\n\x0fTIMER_PROCESSED\x10\x04\x12\n\n\x06\x43LOSED\x10\x05\x62\x06proto3'
)
_globals = globals()
@@ -50,8 +50,8 @@
)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
- _globals["_HANDLESTATE"]._serialized_start = 6408
- _globals["_HANDLESTATE"]._serialized_end = 6518
+ _globals["_HANDLESTATE"]._serialized_start = 6695
+ _globals["_HANDLESTATE"]._serialized_end = 6805
_globals["_STATEREQUEST"]._serialized_start = 112
_globals["_STATEREQUEST"]._serialized_end = 756
_globals["_STATERESPONSE"]._serialized_start = 758
@@ -60,78 +60,80 @@
_globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_end = 985
_globals["_STATERESPONSEWITHSTRINGTYPEVAL"]._serialized_start = 987
_globals["_STATERESPONSEWITHSTRINGTYPEVAL"]._serialized_end = 1109
- _globals["_STATEFULPROCESSORCALL"]._serialized_start = 1112
- _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1784
- _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1787
- _globals["_STATEVARIABLEREQUEST"]._serialized_end = 2128
- _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 2131
- _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 2390
- _globals["_TIMERREQUEST"]._serialized_start = 2393
- _globals["_TIMERREQUEST"]._serialized_end = 2650
- _globals["_TIMERVALUEREQUEST"]._serialized_start = 2653
- _globals["_TIMERVALUEREQUEST"]._serialized_end = 2899
- _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2901
- _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2967
- _globals["_GETPROCESSINGTIME"]._serialized_start = 2969
- _globals["_GETPROCESSINGTIME"]._serialized_end = 2988
- _globals["_GETWATERMARK"]._serialized_start = 2990
- _globals["_GETWATERMARK"]._serialized_end = 3004
- _globals["_UTILSREQUEST"]._serialized_start = 3007
- _globals["_UTILSREQUEST"]._serialized_end = 3146
- _globals["_PARSESTRINGSCHEMA"]._serialized_start = 3148
- _globals["_PARSESTRINGSCHEMA"]._serialized_end = 3191
- _globals["_STATECALLCOMMAND"]._serialized_start = 3194
- _globals["_STATECALLCOMMAND"]._serialized_end = 3393
- _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 3396
- _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 3691
- _globals["_VALUESTATECALL"]._serialized_start = 3694
- _globals["_VALUESTATECALL"]._serialized_end = 4096
- _globals["_LISTSTATECALL"]._serialized_start = 4099
- _globals["_LISTSTATECALL"]._serialized_end = 4706
- _globals["_MAPSTATECALL"]._serialized_start = 4709
- _globals["_MAPSTATECALL"]._serialized_end = 5543
- _globals["_SETIMPLICITKEY"]._serialized_start = 5545
- _globals["_SETIMPLICITKEY"]._serialized_end = 5579
- _globals["_REMOVEIMPLICITKEY"]._serialized_start = 5581
- _globals["_REMOVEIMPLICITKEY"]._serialized_end = 5600
- _globals["_EXISTS"]._serialized_start = 5602
- _globals["_EXISTS"]._serialized_end = 5610
- _globals["_GET"]._serialized_start = 5612
- _globals["_GET"]._serialized_end = 5617
- _globals["_REGISTERTIMER"]._serialized_start = 5619
- _globals["_REGISTERTIMER"]._serialized_end = 5680
- _globals["_DELETETIMER"]._serialized_start = 5682
- _globals["_DELETETIMER"]._serialized_end = 5741
- _globals["_LISTTIMERS"]._serialized_start = 5743
- _globals["_LISTTIMERS"]._serialized_end = 5787
- _globals["_VALUESTATEUPDATE"]._serialized_start = 5789
- _globals["_VALUESTATEUPDATE"]._serialized_end = 5829
- _globals["_CLEAR"]._serialized_start = 5831
- _globals["_CLEAR"]._serialized_end = 5838
- _globals["_LISTSTATEGET"]._serialized_start = 5840
- _globals["_LISTSTATEGET"]._serialized_end = 5886
- _globals["_LISTSTATEPUT"]._serialized_start = 5888
- _globals["_LISTSTATEPUT"]._serialized_end = 5902
- _globals["_APPENDVALUE"]._serialized_start = 5904
- _globals["_APPENDVALUE"]._serialized_end = 5939
- _globals["_APPENDLIST"]._serialized_start = 5941
- _globals["_APPENDLIST"]._serialized_end = 5953
- _globals["_GETVALUE"]._serialized_start = 5955
- _globals["_GETVALUE"]._serialized_end = 5991
- _globals["_CONTAINSKEY"]._serialized_start = 5993
- _globals["_CONTAINSKEY"]._serialized_end = 6032
- _globals["_UPDATEVALUE"]._serialized_start = 6034
- _globals["_UPDATEVALUE"]._serialized_end = 6095
- _globals["_ITERATOR"]._serialized_start = 6097
- _globals["_ITERATOR"]._serialized_end = 6139
- _globals["_KEYS"]._serialized_start = 6141
- _globals["_KEYS"]._serialized_end = 6179
- _globals["_VALUES"]._serialized_start = 6181
- _globals["_VALUES"]._serialized_end = 6221
- _globals["_REMOVEKEY"]._serialized_start = 6223
- _globals["_REMOVEKEY"]._serialized_end = 6260
- _globals["_SETHANDLESTATE"]._serialized_start = 6262
- _globals["_SETHANDLESTATE"]._serialized_end = 6361
- _globals["_TTLCONFIG"]._serialized_start = 6363
- _globals["_TTLCONFIG"]._serialized_end = 6406
+ _globals["_STATERESPONSEWITHLISTGET"]._serialized_start = 1112
+ _globals["_STATERESPONSEWITHLISTGET"]._serialized_end = 1272
+ _globals["_STATEFULPROCESSORCALL"]._serialized_start = 1275
+ _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1947
+ _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1950
+ _globals["_STATEVARIABLEREQUEST"]._serialized_end = 2291
+ _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 2294
+ _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 2553
+ _globals["_TIMERREQUEST"]._serialized_start = 2556
+ _globals["_TIMERREQUEST"]._serialized_end = 2813
+ _globals["_TIMERVALUEREQUEST"]._serialized_start = 2816
+ _globals["_TIMERVALUEREQUEST"]._serialized_end = 3062
+ _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 3064
+ _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 3130
+ _globals["_GETPROCESSINGTIME"]._serialized_start = 3132
+ _globals["_GETPROCESSINGTIME"]._serialized_end = 3151
+ _globals["_GETWATERMARK"]._serialized_start = 3153
+ _globals["_GETWATERMARK"]._serialized_end = 3167
+ _globals["_UTILSREQUEST"]._serialized_start = 3170
+ _globals["_UTILSREQUEST"]._serialized_end = 3309
+ _globals["_PARSESTRINGSCHEMA"]._serialized_start = 3311
+ _globals["_PARSESTRINGSCHEMA"]._serialized_end = 3354
+ _globals["_STATECALLCOMMAND"]._serialized_start = 3357
+ _globals["_STATECALLCOMMAND"]._serialized_end = 3556
+ _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 3559
+ _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 3854
+ _globals["_VALUESTATECALL"]._serialized_start = 3857
+ _globals["_VALUESTATECALL"]._serialized_end = 4259
+ _globals["_LISTSTATECALL"]._serialized_start = 4262
+ _globals["_LISTSTATECALL"]._serialized_end = 4869
+ _globals["_MAPSTATECALL"]._serialized_start = 4872
+ _globals["_MAPSTATECALL"]._serialized_end = 5706
+ _globals["_SETIMPLICITKEY"]._serialized_start = 5708
+ _globals["_SETIMPLICITKEY"]._serialized_end = 5742
+ _globals["_REMOVEIMPLICITKEY"]._serialized_start = 5744
+ _globals["_REMOVEIMPLICITKEY"]._serialized_end = 5763
+ _globals["_EXISTS"]._serialized_start = 5765
+ _globals["_EXISTS"]._serialized_end = 5773
+ _globals["_GET"]._serialized_start = 5775
+ _globals["_GET"]._serialized_end = 5780
+ _globals["_REGISTERTIMER"]._serialized_start = 5782
+ _globals["_REGISTERTIMER"]._serialized_end = 5843
+ _globals["_DELETETIMER"]._serialized_start = 5845
+ _globals["_DELETETIMER"]._serialized_end = 5904
+ _globals["_LISTTIMERS"]._serialized_start = 5906
+ _globals["_LISTTIMERS"]._serialized_end = 5950
+ _globals["_VALUESTATEUPDATE"]._serialized_start = 5952
+ _globals["_VALUESTATEUPDATE"]._serialized_end = 5992
+ _globals["_CLEAR"]._serialized_start = 5994
+ _globals["_CLEAR"]._serialized_end = 6001
+ _globals["_LISTSTATEGET"]._serialized_start = 6003
+ _globals["_LISTSTATEGET"]._serialized_end = 6049
+ _globals["_LISTSTATEPUT"]._serialized_start = 6051
+ _globals["_LISTSTATEPUT"]._serialized_end = 6127
+ _globals["_APPENDVALUE"]._serialized_start = 6129
+ _globals["_APPENDVALUE"]._serialized_end = 6164
+ _globals["_APPENDLIST"]._serialized_start = 6166
+ _globals["_APPENDLIST"]._serialized_end = 6240
+ _globals["_GETVALUE"]._serialized_start = 6242
+ _globals["_GETVALUE"]._serialized_end = 6278
+ _globals["_CONTAINSKEY"]._serialized_start = 6280
+ _globals["_CONTAINSKEY"]._serialized_end = 6319
+ _globals["_UPDATEVALUE"]._serialized_start = 6321
+ _globals["_UPDATEVALUE"]._serialized_end = 6382
+ _globals["_ITERATOR"]._serialized_start = 6384
+ _globals["_ITERATOR"]._serialized_end = 6426
+ _globals["_KEYS"]._serialized_start = 6428
+ _globals["_KEYS"]._serialized_end = 6466
+ _globals["_VALUES"]._serialized_start = 6468
+ _globals["_VALUES"]._serialized_end = 6508
+ _globals["_REMOVEKEY"]._serialized_start = 6510
+ _globals["_REMOVEKEY"]._serialized_end = 6547
+ _globals["_SETHANDLESTATE"]._serialized_start = 6549
+ _globals["_SETHANDLESTATE"]._serialized_end = 6648
+ _globals["_TTLCONFIG"]._serialized_start = 6650
+ _globals["_TTLCONFIG"]._serialized_end = 6693
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
index ac4b03b820349..aa86826862bb1 100644
--- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
+++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
@@ -34,7 +34,9 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import builtins
+import collections.abc
import google.protobuf.descriptor
+import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
@@ -229,6 +231,44 @@ class StateResponseWithStringTypeVal(google.protobuf.message.Message):
global___StateResponseWithStringTypeVal = StateResponseWithStringTypeVal
+class StateResponseWithListGet(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ STATUSCODE_FIELD_NUMBER: builtins.int
+ ERRORMESSAGE_FIELD_NUMBER: builtins.int
+ VALUE_FIELD_NUMBER: builtins.int
+ REQUIRENEXTFETCH_FIELD_NUMBER: builtins.int
+ statusCode: builtins.int
+ errorMessage: builtins.str
+ @property
+ def value(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
+ requireNextFetch: builtins.bool
+ def __init__(
+ self,
+ *,
+ statusCode: builtins.int = ...,
+ errorMessage: builtins.str = ...,
+ value: collections.abc.Iterable[builtins.bytes] | None = ...,
+ requireNextFetch: builtins.bool = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "errorMessage",
+ b"errorMessage",
+ "requireNextFetch",
+ b"requireNextFetch",
+ "statusCode",
+ b"statusCode",
+ "value",
+ b"value",
+ ],
+ ) -> None: ...
+
+global___StateResponseWithListGet = StateResponseWithListGet
+
class StatefulProcessorCall(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -1042,8 +1082,24 @@ global___ListStateGet = ListStateGet
class ListStatePut(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ VALUE_FIELD_NUMBER: builtins.int
+ FETCHWITHARROW_FIELD_NUMBER: builtins.int
+ @property
+ def value(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
+ fetchWithArrow: builtins.bool
def __init__(
self,
+ *,
+ value: collections.abc.Iterable[builtins.bytes] | None = ...,
+ fetchWithArrow: builtins.bool = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "fetchWithArrow", b"fetchWithArrow", "value", b"value"
+ ],
) -> None: ...
global___ListStatePut = ListStatePut
@@ -1065,8 +1121,24 @@ global___AppendValue = AppendValue
class AppendList(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ VALUE_FIELD_NUMBER: builtins.int
+ FETCHWITHARROW_FIELD_NUMBER: builtins.int
+ @property
+ def value(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
+ fetchWithArrow: builtins.bool
def __init__(
self,
+ *,
+ value: collections.abc.Iterable[builtins.bytes] | None = ...,
+ fetchWithArrow: builtins.bool = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "fetchWithArrow", b"fetchWithArrow", "value", b"value"
+ ],
) -> None: ...
global___AppendList = AppendList
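As a quick illustration of the regenerated messages above, the new value/fetchWithArrow fields can be populated like this. A hedged usage sketch: the payload bytes and state name below are made up, while the message and field names follow the stubs above.

import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

append = stateMessage.AppendList(value=[b"row-1", b"row-2"], fetchWithArrow=False)
call = stateMessage.ListStateCall(stateName="my_list_state", appendList=append)
request = stateMessage.StateRequest(
    version=1,
    stateVariableRequest=stateMessage.StateVariableRequest(listStateCall=call),
)
payload = request.SerializeToString()  # bytes handed to the state server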
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index a7349779dc626..ab988eb714cc6 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -21,7 +21,7 @@
from typing import IO, Iterator, Tuple
from pyspark.accumulators import _accumulatorRegistry
-from pyspark.errors import IllegalArgumentException, PySparkAssertionError, PySparkRuntimeError
+from pyspark.errors import IllegalArgumentException, PySparkAssertionError
from pyspark.serializers import (
read_int,
write_int,
@@ -78,6 +78,7 @@ def partitions_func(
start_offset = json.loads(utf8_deserializer.loads(infile))
end_offset = json.loads(utf8_deserializer.loads(infile))
partitions = reader.partitions(start_offset, end_offset)
+
# Return the serialized partition values.
write_int(len(partitions), outfile)
for partition in partitions:
@@ -183,12 +184,6 @@ def main(infile: IO, outfile: IO) -> None:
},
)
outfile.flush()
- except Exception as e:
- error_msg = "data source {} throw exception: {}".format(data_source.name, e)
- raise PySparkRuntimeError(
- errorClass="PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
- messageParameters={"msg": error_msg},
- )
finally:
reader.stop()
except BaseException as e:
@@ -209,9 +204,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
# Prevent the socket from timeout error when query trigger interval is large.
sock.settimeout(None)
write_int(os.getpid(), sock_file)
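The connection bootstrap above follows a simple convention: if the worker factory exported a Unix domain socket path, use it; otherwise fall back to the TCP port, with -1 meaning "not provided". A minimal sketch of that resolution, using the same environment variable names as the patch:

import os
from typing import Union

def resolve_conn_info() -> Union[str, int]:
    # Prefer a Unix domain socket path when available.
    sock_path = os.environ.get("PYTHON_WORKER_FACTORY_SOCK_PATH")
    if sock_path is not None:
        return sock_path
    # Otherwise fall back to a TCP port on localhost; -1 signals "not provided".
    return int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))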
diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py
index d836ec4e2e82b..dbece47a93ceb 100644
--- a/python/pyspark/sql/streaming/stateful_processor.py
+++ b/python/pyspark/sql/streaming/stateful_processor.py
@@ -29,7 +29,7 @@
MapStateKeyValuePairIterator,
)
from pyspark.sql.streaming.value_state_client import ValueStateClient
-from pyspark.sql.types import StructType
+from pyspark.sql.types import StructType, Row
if TYPE_CHECKING:
from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
@@ -359,6 +359,14 @@ class StatefulProcessor(ABC):
Class that represents the arbitrary stateful logic that needs to be provided by the user to
perform stateful manipulations on keyed streams.
+ NOTE: The types of the input data and the return value differ depending on which method is
+ used, as follows:
+
+ `transformWithStateInPandas` - :class:`pandas.DataFrame`
+ `transformWithState` - :class:`pyspark.sql.Row`
+
+ An implementation of this class must follow this type assignment, which implies that an
+ implementation is effectively bound to one of the two methods.
+
.. versionadded:: 4.0.0
"""
@@ -380,25 +388,29 @@ def init(self, handle: StatefulProcessorHandle) -> None:
def handleInputRows(
self,
key: Any,
- rows: Iterator["PandasDataFrameLike"],
+ rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
timerValues: TimerValues,
- ) -> Iterator["PandasDataFrameLike"]:
+ ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
"""
Function that will allow users to interact with input data rows along with the grouping key.
- It should take parameters (key, Iterator[`pandas.DataFrame`]) and return another
- Iterator[`pandas.DataFrame`]. For each group, all columns are passed together as
- `pandas.DataFrame` to the function, and the returned `pandas.DataFrame` across all
- invocations are combined as a :class:`DataFrame`. Note that the function should not make a
- guess of the number of elements in the iterator. To process all data, the `handleInputRows`
- function needs to iterate all elements and process them. On the other hand, the
- `handleInputRows` function is not strictly required to iterate through all elements in the
- iterator if it intends to read a part of data.
+
+ The types of the input data and the return value differ depending on which method is used:
+
+ For `transformWithStateInPandas`, it should take parameters
+ (key, Iterator[`pandas.DataFrame`]) and return another Iterator[`pandas.DataFrame`].
+ For `transformWithState`, it should take parameters (key, Iterator[`pyspark.sql.Row`])
+ and return another Iterator[`pyspark.sql.Row`].
+
+ Note that the function should not assume the number of elements in the iterator. To process
+ all data, `handleInputRows` needs to iterate over all elements and process them. On the other
+ hand, `handleInputRows` is not strictly required to iterate through all elements if it only
+ intends to read part of the data.
Parameters
----------
key : Any
grouping key.
- rows : iterable of :class:`pandas.DataFrame`
+ rows : iterable of :class:`pandas.DataFrame` or iterable of :class:`pyspark.sql.Row`
iterator of input rows associated with grouping key
timerValues: TimerValues
Timer value for the current batch that process the input rows.
@@ -408,12 +420,17 @@ def handleInputRows(
def handleExpiredTimer(
self, key: Any, timerValues: TimerValues, expiredTimerInfo: ExpiredTimerInfo
- ) -> Iterator["PandasDataFrameLike"]:
+ ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
"""
Optional to implement. Will return an empty iterator if not defined.
Function that will be invoked when a timer is fired for a given key. Users can choose to
evict state, register new timers and optionally provide output rows.
+ The return type differs depending on which method is used:
+
+ For `transformWithStateInPandas`, it should return Iterator[`pandas.DataFrame`].
+ For `transformWithState`, it should return Iterator[`pyspark.sql.Row`].
+
Parameters
----------
key : Any
@@ -426,7 +443,6 @@ def handleExpiredTimer(
"""
return iter([])
- @abstractmethod
def close(self) -> None:
"""
Function called as the last method that allows for users to perform any cleanup or teardown
@@ -435,18 +451,23 @@ def close(self) -> None:
...
def handleInitialState(
- self, key: Any, initialState: "PandasDataFrameLike", timerValues: TimerValues
+ self, key: Any, initialState: Union["PandasDataFrameLike", Row], timerValues: TimerValues
) -> None:
"""
Optional to implement. Will act as no-op if not defined or no initial state input.
- Function that will be invoked only in the first batch for users to process initial states.
+ Function that will be invoked only in the first batch for users to process initial states.
+
+ The type of the initial state differs depending on which method is used:
+
+ For `transformWithStateInPandas`, it should take `pandas.DataFrame`.
+ For `transformWithState`, it should take `pyspark.sql.Row`.
Parameters
----------
key : Any
grouping key.
- initialState: :class:`pandas.DataFrame`
- One dataframe in the initial state associated with the key.
+ initialState: :class:`pandas.DataFrame` or :class:`pyspark.sql.Row`
+ One dataframe/row in the initial state associated with the key.
timerValues: TimerValues
Timer value for the current batch that process the input rows.
Users can get the processing or event time timestamp from TimerValues.
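To make the Row-based contract concrete, here is a minimal, hedged sketch of a processor intended for `transformWithState` (not `transformWithStateInPandas`). It assumes the existing handle API (getValueState accepting a string schema, ValueState.exists/get/update) behaves as in the current codebase, that the grouping key arrives as a tuple, and the state name "count" and output columns are hypothetical.

from typing import Any, Iterator
from pyspark.sql import Row
from pyspark.sql.streaming.stateful_processor import (
    StatefulProcessor,
    StatefulProcessorHandle,
    TimerValues,
)

class RunningCountProcessor(StatefulProcessor):
    """Counts rows per grouping key and emits the running total as Rows."""

    def init(self, handle: StatefulProcessorHandle) -> None:
        # Value state holding a single-column tuple; a DDL string schema is assumed to be accepted.
        self._count = handle.getValueState("count", "count long")

    def handleInputRows(
        self, key: Any, rows: Iterator[Row], timerValues: TimerValues
    ) -> Iterator[Row]:
        previous = self._count.get()[0] if self._count.exists() else 0
        total = previous + sum(1 for _ in rows)
        self._count.update((total,))
        # key is assumed to be a tuple of grouping-key values.
        yield Row(id=key[0], count=total)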
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 6fd56481bc612..18330c4096fa6 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -49,11 +49,27 @@ class StatefulProcessorHandleState(Enum):
class StatefulProcessorApiClient:
def __init__(
- self, state_server_port: int, key_schema: StructType, is_driver: bool = False
+ self, state_server_port: Union[int, str], key_schema: StructType, is_driver: bool = False
) -> None:
self.key_schema = key_schema
- self._client_socket = socket.socket()
- self._client_socket.connect(("localhost", state_server_port))
+ if isinstance(state_server_port, str):
+ self._client_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ self._client_socket.connect(state_server_port)
+ else:
+ self._client_socket = socket.socket()
+ self._client_socket.connect(("localhost", state_server_port))
+
+ # SPARK-51667: We have a pattern of sending messages continuously from one side
+ # (Python -> JVM, and vice versa) before getting a response from the other side. Since
+ # most messages we send are small, this triggers the bad combination of Nagle's
+ # algorithm and delayed ACKs, which can add significant latency.
+ # See SPARK-51667 for more details on how this can be a problem.
+ #
+ # Disabling either would work, but it's more common to disable Nagle's algorithm; there
+ # are far fewer references on disabling delayed ACKs, while there are plenty of resources
+ # on disabling Nagle's algorithm.
+ self._client_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
self.sockfile = self._client_socket.makefile(
"rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
)
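In isolation, the socket setup the comment above motivates looks roughly like this. A sketch for the TCP branch only: TCP_NODELAY is an IPPROTO_TCP option and does not apply to the AF_UNIX branch; the helper name is made up.

import socket

def connect_state_server(port: int) -> socket.socket:
    s = socket.socket()  # AF_INET, SOCK_STREAM by default
    s.connect(("localhost", port))
    # Disable Nagle's algorithm so small request/response messages are not held back
    # waiting for the ACK of a previous small write.
    s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    return s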
@@ -409,6 +425,18 @@ def _receive_proto_message_with_string_value(self) -> Tuple[int, str, str]:
message.ParseFromString(bytes)
return message.statusCode, message.errorMessage, message.value
+ # The third element of the returned tuple has the protobuf container type
+ # RepeatedScalarFieldContainer[bytes]. We simplify it to Any here to avoid unnecessary
+ # complexity.
+ def _receive_proto_message_with_list_get(self) -> Tuple[int, str, Any, bool]:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ length = read_int(self.sockfile)
+ bytes = self.sockfile.read(length)
+ message = stateMessage.StateResponseWithListGet()
+ message.ParseFromString(bytes)
+
+ return message.statusCode, message.errorMessage, message.value, message.requireNextFetch
+
def _receive_str(self) -> str:
return self.utf8_deserializer.loads(self.sockfile)
@@ -455,6 +483,26 @@ def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None:
def _read_arrow_state(self) -> Any:
return self.serializer.load_stream(self.sockfile)
+ def _send_list_state(self, schema: StructType, state: List[Tuple]) -> None:
+ for value in state:
+ bytes = self._serialize_to_bytes(schema, value)
+ length = len(bytes)
+ write_int(length, self.sockfile)
+ self.sockfile.write(bytes)
+
+ write_int(-1, self.sockfile)
+ self.sockfile.flush()
+
+ def _read_list_state(self) -> List[Any]:
+ data_array = []
+ while True:
+ length = read_int(self.sockfile)
+ if length < 0:
+ break
+ bytes = self.sockfile.read(length)
+ data_array.append(self._deserialize_from_bytes(bytes))
+ return data_array
+
# Parse a string schema into a StructType schema. This method will perform an API call to
# JVM side to parse the schema string.
def _parse_string_schema(self, schema: str) -> StructType:
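The _send_list_state/_read_list_state helpers above use a simple length-prefixed framing: each element is written as a 4-byte length followed by the payload, and a length of -1 marks end of stream. A self-contained sketch over an in-memory buffer, using only the standard library and assuming the same big-endian int framing as pyspark's read_int/write_int:

import io
import struct
from typing import List

def write_frames(buf: io.BytesIO, payloads: List[bytes]) -> None:
    for payload in payloads:
        buf.write(struct.pack("!i", len(payload)))
        buf.write(payload)
    buf.write(struct.pack("!i", -1))  # end-of-stream marker

def read_frames(buf: io.BytesIO) -> List[bytes]:
    frames = []
    while True:
        (length,) = struct.unpack("!i", buf.read(4))
        if length < 0:
            break
        frames.append(buf.read(length))
    return frames

buf = io.BytesIO()
write_frames(buf, [b"row-1", b"row-2"])
buf.seek(0)
assert read_frames(buf) == [b"row-1", b"row-2"]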
diff --git a/python/pyspark/sql/streaming/stateful_processor_util.py b/python/pyspark/sql/streaming/stateful_processor_util.py
index fbc3093f87092..c0ff176eb9c90 100644
--- a/python/pyspark/sql/streaming/stateful_processor_util.py
+++ b/python/pyspark/sql/streaming/stateful_processor_util.py
@@ -17,7 +17,7 @@
from enum import Enum
import itertools
-from typing import Any, Iterator, Optional, TYPE_CHECKING
+from typing import Any, Iterator, Optional, TYPE_CHECKING, Union
from pyspark.sql.streaming.stateful_processor_api_client import (
StatefulProcessorApiClient,
StatefulProcessorHandleState,
@@ -28,18 +28,20 @@
StatefulProcessorHandle,
TimerValues,
)
+from pyspark.sql.types import Row
if TYPE_CHECKING:
from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
-# This file places the utilities for transformWithStateInPandas; we have a separate file to avoid
-# putting internal classes to the stateful_processor.py file which contains public APIs.
+# This file contains the utilities for transformWithState in PySpark (Row and Pandas variants);
+# it is kept separate to avoid putting internal classes into stateful_processor.py, which
+# contains public APIs.
-class TransformWithStateInPandasFuncMode(Enum):
+class TransformWithStateInPySparkFuncMode(Enum):
"""
- Internal mode for python worker UDF mode for transformWithStateInPandas; external mode are in
- `StatefulProcessorHandleState` for public use purposes.
+ Internal modes for the Python worker UDF for transformWithState in PySpark; external modes
+ are in `StatefulProcessorHandleState` for public use.
"""
PROCESS_DATA = 1
@@ -48,10 +50,10 @@ class TransformWithStateInPandasFuncMode(Enum):
PRE_INIT = 4
-class TransformWithStateInPandasUdfUtils:
+class TransformWithStateInPySparkUdfUtils:
"""
- Internal Utility class used for python worker UDF for transformWithStateInPandas. This class is
- shared for both classic and spark connect mode.
+ Internal utility class used by the Python worker UDF for transformWithState in PySpark. This
+ class is shared by both classic and Spark Connect modes.
"""
def __init__(self, stateful_processor: StatefulProcessor, time_mode: str):
@@ -61,11 +63,11 @@ def __init__(self, stateful_processor: StatefulProcessor, time_mode: str):
def transformWithStateUDF(
self,
stateful_processor_api_client: StatefulProcessorApiClient,
- mode: TransformWithStateInPandasFuncMode,
+ mode: TransformWithStateInPySparkFuncMode,
key: Any,
- input_rows: Iterator["PandasDataFrameLike"],
- ) -> Iterator["PandasDataFrameLike"]:
- if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
+ input_rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
+ ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
+ if mode == TransformWithStateInPySparkFuncMode.PRE_INIT:
return self._handle_pre_init(stateful_processor_api_client)
handle = StatefulProcessorHandle(stateful_processor_api_client)
@@ -74,13 +76,13 @@ def transformWithStateUDF(
self._stateful_processor.init(handle)
stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)
- if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
+ if mode == TransformWithStateInPySparkFuncMode.PROCESS_TIMER:
stateful_processor_api_client.set_handle_state(
StatefulProcessorHandleState.DATA_PROCESSED
)
result = self._handle_expired_timers(stateful_processor_api_client)
return result
- elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
+ elif mode == TransformWithStateInPySparkFuncMode.COMPLETE:
stateful_processor_api_client.set_handle_state(
StatefulProcessorHandleState.TIMER_PROCESSED
)
@@ -89,18 +91,18 @@ def transformWithStateUDF(
stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
return iter([])
else:
- # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
+ # mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA
result = self._handle_data_rows(stateful_processor_api_client, key, input_rows)
return result
def transformWithStateWithInitStateUDF(
self,
stateful_processor_api_client: StatefulProcessorApiClient,
- mode: TransformWithStateInPandasFuncMode,
+ mode: TransformWithStateInPySparkFuncMode,
key: Any,
- input_rows: Iterator["PandasDataFrameLike"],
- initial_states: Optional[Iterator["PandasDataFrameLike"]] = None,
- ) -> Iterator["PandasDataFrameLike"]:
+ input_rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
+ initial_states: Optional[Union[Iterator["PandasDataFrameLike"], Iterator[Row]]] = None,
+ ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
"""
UDF for TWS operator with non-empty initial states. Possible input combinations
of inputRows and initialStates iterator:
@@ -113,7 +115,7 @@ def transformWithStateWithInitStateUDF(
- `initialStates` is None, while `inputRows` is not empty. This is not first batch.
`initialStates` is initialized to the positional value as None.
"""
- if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
+ if mode == TransformWithStateInPySparkFuncMode.PRE_INIT:
return self._handle_pre_init(stateful_processor_api_client)
handle = StatefulProcessorHandle(stateful_processor_api_client)
@@ -122,19 +124,19 @@ def transformWithStateWithInitStateUDF(
self._stateful_processor.init(handle)
stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)
- if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
+ if mode == TransformWithStateInPySparkFuncMode.PROCESS_TIMER:
stateful_processor_api_client.set_handle_state(
StatefulProcessorHandleState.DATA_PROCESSED
)
result = self._handle_expired_timers(stateful_processor_api_client)
return result
- elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
+ elif mode == TransformWithStateInPySparkFuncMode.COMPLETE:
stateful_processor_api_client.remove_implicit_key()
self._stateful_processor.close()
stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
return iter([])
else:
- # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
+ # mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA
batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
self._time_mode
)
@@ -155,7 +157,7 @@ def transformWithStateWithInitStateUDF(
except StopIteration:
input_rows_empty = True
else:
- input_rows = itertools.chain([first], input_rows)
+ input_rows = itertools.chain([first], input_rows) # type: ignore
if not input_rows_empty:
result = self._handle_data_rows(stateful_processor_api_client, key, input_rows)
@@ -166,7 +168,7 @@ def transformWithStateWithInitStateUDF(
def _handle_pre_init(
self, stateful_processor_api_client: StatefulProcessorApiClient
- ) -> Iterator["PandasDataFrameLike"]:
+ ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
# Driver handle is different from the handle used on executors;
# On JVM side, we will use `DriverStatefulProcessorHandleImpl` for driver handle which
# will only be used for handling init() and get the state schema on the driver.
@@ -186,8 +188,8 @@ def _handle_data_rows(
self,
stateful_processor_api_client: StatefulProcessorApiClient,
key: Any,
- input_rows: Optional[Iterator["PandasDataFrameLike"]] = None,
- ) -> Iterator["PandasDataFrameLike"]:
+ input_rows: Optional[Union[Iterator["PandasDataFrameLike"], Iterator[Row]]] = None,
+ ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
stateful_processor_api_client.set_implicit_key(key)
batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
@@ -206,7 +208,7 @@ def _handle_data_rows(
def _handle_expired_timers(
self,
stateful_processor_api_client: StatefulProcessorApiClient,
- ) -> Iterator["PandasDataFrameLike"]:
+ ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
self._time_mode
)
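One small idiom from the hunk above is worth calling out: to detect an empty input iterator without losing its first element, the code pulls one item and chains it back in front. A standalone sketch of the same trick (helper name is made up):

import itertools
from typing import Iterator, Tuple

def probe_empty(rows: Iterator[Tuple]) -> Tuple[bool, Iterator[Tuple]]:
    try:
        first = next(rows)
    except StopIteration:
        return True, iter([])
    # Re-attach the consumed element so downstream code sees the full stream.
    return False, itertools.chain([first], rows)

empty, rows = probe_empty(iter([(1,), (2,)]))
assert not empty and list(rows) == [(1,), (2,)]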
diff --git a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
index 99d386f07b5b6..8d9bed7e61875 100644
--- a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
+++ b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
@@ -31,7 +31,7 @@
from typing import IO
from pyspark.worker_util import check_python_version
from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPySparkFuncMode
from pyspark.sql.types import StructType
if TYPE_CHECKING:
@@ -51,7 +51,7 @@ def main(infile: IO, outfile: IO) -> None:
def process(
processor: StatefulProcessorApiClient,
- mode: TransformWithStateInPandasFuncMode,
+ mode: TransformWithStateInPySparkFuncMode,
key: Any,
input: Iterator["PandasDataFrameLike"],
) -> None:
@@ -72,16 +72,18 @@ def process(
# This driver runner will only be used on the first batch of a query,
# and the following code block should be only run once for each query run
state_server_port = read_int(infile)
+ if state_server_port == -1:
+ state_server_port = utf8_deserializer.loads(infile)
key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
print(
- f"{log_name} received parameters for UDF. State server port: {state_server_port}, "
+ f"{log_name} received parameters for UDF. State server port/path: {state_server_port}, "
f"key schema: {key_schema}.\n"
)
stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
process(
stateful_processor_api_client,
- TransformWithStateInPandasFuncMode.PRE_INIT,
+ TransformWithStateInPySparkFuncMode.PRE_INIT,
None,
iter([]),
)
@@ -94,9 +96,11 @@ def process(
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
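For reference, a minimal standalone sketch (an illustration under stated assumptions, not code from this patch) of the connection-info resolution the updated __main__ block performs: a Unix domain socket path from PYTHON_WORKER_FACTORY_SOCK_PATH takes precedence, and the TCP port from PYTHON_WORKER_FACTORY_PORT is the fallback, with -1 as the "not set" sentinel.

    import os

    def resolve_worker_conn_info():
        """Return a UDS path (str) if provided, else a TCP port (int); -1 means unset."""
        sock_path = os.environ.get("PYTHON_WORKER_FACTORY_SOCK_PATH")
        if sock_path is not None:
            return sock_path
        return int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))

    # The resolved value is then handed to local_connect_and_auth(conn_info, auth_secret),
    # which is assumed to accept either form, as in the worker code above.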
diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py b/python/pyspark/sql/tests/arrow/test_arrow.py
index 065f97fcf7c78..5a770a947889b 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow.py
@@ -45,14 +45,13 @@
NullType,
DayTimeIntervalType,
)
+from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
- ExamplePoint,
- ExamplePointUDT,
)
from pyspark.errors import ArithmeticException, PySparkTypeError, UnsupportedOperationException
from pyspark.loose_version import LooseVersion
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_map.py b/python/pyspark/sql/tests/arrow/test_arrow_map.py
index 86dca7ed92d0b..fa2ce69c4fa53 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_map.py
@@ -124,7 +124,7 @@ def test_empty_rows(self):
def empty_rows(_):
return iter([pa.RecordBatch.from_pandas(pd.DataFrame({"a": []}))])
- self.assertEqual(self.spark.range(10).mapInArrow(empty_rows, "a int").count(), 0)
+ self.assertEqual(self.spark.range(10).mapInArrow(empty_rows, "a double").count(), 0)
def test_chain_map_in_arrow(self):
def func(iterator):
@@ -175,6 +175,44 @@ def test_negative_and_zero_batch_size(self):
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
MapInArrowTests.test_map_in_arrow(self)
+ def test_nested_extraneous_field(self):
+ def func(iterator):
+ for _ in iterator:
+ struct_arr = pa.StructArray.from_arrays([[1, 2], [3, 4]], names=["a", "b"])
+ yield pa.RecordBatch.from_arrays([struct_arr], ["x"])
+
+ df = self.spark.range(1)
+ with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
+ df.mapInArrow(func, "x struct").collect()
+
+ def test_top_level_wrong_order(self):
+ def func(iterator):
+ for _ in iterator:
+ yield pa.RecordBatch.from_arrays([[1], [2]], ["b", "a"])
+
+ df = self.spark.range(1)
+ with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
+ df.mapInArrow(func, "a int, b int").collect()
+
+ def test_nullability_widen(self):
+ def func(iterator):
+ for _ in iterator:
+ yield pa.RecordBatch.from_arrays([[1]], ["a"])
+
+ df = self.spark.range(1)
+ with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
+ df.mapInArrow(func, "a int not null").collect()
+
+ def test_nullability_narrow(self):
+ def func(iterator):
+ for _ in iterator:
+ yield pa.RecordBatch.from_arrays(
+ [[1]], pa.schema([pa.field("a", pa.int32(), nullable=False)])
+ )
+
+ df = self.spark.range(1)
+ df.mapInArrow(func, "a int").collect()
+
class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
@@ -208,6 +246,14 @@ def setUpClass(cls):
cls.spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "10")
+class MapInArrowWithOutputArrowBatchSlicingTests(MapInArrowTests):
+ @classmethod
+ def setUpClass(cls):
+ MapInArrowTests.setUpClass()
+ cls.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10")
+ cls.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerOutputBatch", "3")
+
+
if __name__ == "__main__":
from pyspark.sql.tests.arrow.test_arrow_map import * # noqa: F401
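As a hedged usage sketch of the behaviour the new MapInArrowWithOutputArrowBatchSlicingTests exercises (assumptions: an active `spark` session, and that spark.sql.execution.arrow.maxRecordsPerOutputBatch caps the size of batches emitted back to the JVM), the slicing looks roughly like:

    def passthrough(iterator):
        # Yield input batches unchanged; oversized output batches are expected
        # to be sliced to at most maxRecordsPerOutputBatch records by Spark.
        for batch in iterator:
            yield batch

    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10")
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerOutputBatch", "3")
    assert spark.range(10).mapInArrow(passthrough, "id long").count() == 10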
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
index 78806ad399ab7..9892bc1f9c50c 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -35,18 +35,18 @@
@unittest.skipIf(
not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
)
-class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
+class ArrowPythonUDFTestsMixin(BaseUDFTestsMixin):
@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_broadcast_in_udf(self):
- super(PythonUDFArrowTests, self).test_broadcast_in_udf()
+ super(ArrowPythonUDFTests, self).test_broadcast_in_udf()
@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_register_java_function(self):
- super(PythonUDFArrowTests, self).test_register_java_function()
+ super(ArrowPythonUDFTests, self).test_register_java_function()
@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_register_java_udaf(self):
- super(PythonUDFArrowTests, self).test_register_java_udaf()
+ super(ArrowPythonUDFTests, self).test_register_java_udaf()
def test_complex_input_types(self):
row = (
@@ -214,11 +214,18 @@ def test_udf(a, b):
with self.assertRaises(PythonException):
self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
+ def test_udf_with_udt(self):
+ for fallback in [False, True]:
+ with self.subTest(fallback=fallback), self.sql_conf(
+ {"spark.sql.execution.pythonUDF.arrow.legacy.fallbackOnUDT": fallback}
+ ):
+ super().test_udf_with_udt()
-class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
+
+class ArrowPythonUDFTests(ArrowPythonUDFTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
- super(PythonUDFArrowTests, cls).setUpClass()
+ super(ArrowPythonUDFTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
@classmethod
@@ -226,13 +233,13 @@ def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
- super(PythonUDFArrowTests, cls).tearDownClass()
+ super(ArrowPythonUDFTests, cls).tearDownClass()
-class AsyncPythonUDFArrowTests(PythonUDFArrowTests):
+class AsyncArrowPythonUDFTests(ArrowPythonUDFTests):
@classmethod
def setUpClass(cls):
- super(AsyncPythonUDFArrowTests, cls).setUpClass()
+ super(AsyncArrowPythonUDFTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", "4")
@classmethod
@@ -240,7 +247,7 @@ def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.concurrency.level")
finally:
- super(AsyncPythonUDFArrowTests, cls).tearDownClass()
+ super(AsyncArrowPythonUDFTests, cls).tearDownClass()
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
index fe81513f005f9..8a5fe6131bd3d 100644
--- a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
@@ -16,10 +16,10 @@
#
from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests
-from pyspark.sql.tests.arrow.test_arrow_python_udf import PythonUDFArrowTestsMixin
+from pyspark.sql.tests.arrow.test_arrow_python_udf import ArrowPythonUDFTestsMixin
-class ArrowPythonUDFParityTests(UDFParityTests, PythonUDFArrowTestsMixin):
+class ArrowPythonUDFParityTests(UDFParityTests, ArrowPythonUDFTestsMixin):
@classmethod
def setUpClass(cls):
super(ArrowPythonUDFParityTests, cls).setUpClass()
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index 3066d935131c5..22865be4f42a6 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -224,6 +224,8 @@ def conf(cls):
def test_basic_requests(self):
file_name = "smallJar"
small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
+ if not os.path.isfile(small_jar_path):
+ raise unittest.SkipTest(f"Skipped as {small_jar_path} does not exist.")
response = self.artifact_manager._retrieve_responses(
self.artifact_manager._create_requests(
small_jar_path, pyfile=False, archive=False, file=False
@@ -235,6 +237,8 @@ def test_single_chunk_artifact(self):
file_name = "smallJar"
small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
+ if not os.path.isfile(small_jar_path):
+ raise unittest.SkipTest(f"Skipped as {small_jar_path} does not exist.")
requests = list(
self.artifact_manager._create_requests(
@@ -261,6 +265,8 @@ def test_chunked_artifacts(self):
file_name = "junitLargeJar"
large_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
+ if not os.path.isfile(large_jar_path):
+ raise unittest.SkipTest(f"Skipped as {large_jar_path} does not exist.")
requests = list(
self.artifact_manager._create_requests(
@@ -296,6 +302,8 @@ def test_batched_artifacts(self):
file_name = "smallJar"
small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
+ if not os.path.isfile(small_jar_path):
+ raise unittest.SkipTest(f"Skipped as {small_jar_path} does not exist.")
requests = list(
self.artifact_manager._create_requests(
@@ -333,6 +341,10 @@ def test_single_chunked_and_chunked_artifact(self):
large_jar_path = os.path.join(self.artifact_file_path, f"{file_name2}.jar")
large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name2}.txt")
large_jar_size = os.path.getsize(large_jar_path)
+ if not os.path.isfile(small_jar_path):
+ raise unittest.SkipTest(f"Skipped as {small_jar_path} does not exist.")
+ if not os.path.isfile(large_jar_path):
+ raise unittest.SkipTest(f"Skipped as {large_jar_path} does not exist.")
requests = list(
self.artifact_manager._create_requests(
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 741d6b9c1104e..3ab73adfcea5e 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -137,6 +137,7 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
self.req = req
resp = proto.ExecutePlanResponse()
resp.session_id = self._session_id
+ resp.operation_id = req.operation_id
pdf = pd.DataFrame(data={"col1": [1, 2]})
schema = pa.Schema.from_pandas(pdf)
@@ -255,6 +256,16 @@ def test_channel_builder_with_session(self):
client = SparkConnectClient(chan)
self.assertEqual(client._session_id, chan.session_id)
+ def test_custom_operation_id(self):
+ client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False)
+ mock = MockService(client._session_id)
+ client._stub = mock
+ req = client._execute_plan_request_with_metadata(
+ operation_id="10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
+ )
+ for resp in client._stub.ExecutePlan(req, metadata=None):
+ assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
+
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectClientReattachTestCase(unittest.TestCase):
@@ -404,6 +415,7 @@ def not_found():
pass
self.assertTrue("RESPONSE_ALREADY_RECEIVED" in e.exception.getMessage())
+ self.assertTrue(error_code in e.exception.getMessage())
def checks():
self.assertEqual(1, stub.execute_calls)
diff --git a/python/pyspark/sql/tests/connect/client/test_reattach.py b/python/pyspark/sql/tests/connect/client/test_reattach.py
index 64c81529ec141..18d9a62bf46e9 100644
--- a/python/pyspark/sql/tests/connect/client/test_reattach.py
+++ b/python/pyspark/sql/tests/connect/client/test_reattach.py
@@ -15,39 +15,17 @@
# limitations under the License.
#
-import os
import unittest
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession as PySparkSession
-from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.connectutils import ReusedMixedTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.testing.utils import eventually
@unittest.skipIf(is_remote_only(), "Requires JVM access")
-class SparkConnectReattachTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
- @classmethod
- def setUpClass(cls):
- super(SparkConnectReattachTestCase, cls).setUpClass()
- # Disable the shared namespace so pyspark.sql.functions, etc point the regular
- # PySpark libraries.
- os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
-
- cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session.
- cls.spark = PySparkSession._instantiatedSession
- assert cls.spark is not None
-
- @classmethod
- def tearDownClass(cls):
- try:
- # Stopping Spark Connect closes the session in JVM at the server.
- cls.spark = cls.connect
- del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
- finally:
- super(SparkConnectReattachTestCase, cls).tearDownClass()
-
+class SparkConnectReattachTestCase(ReusedMixedTestCase, PandasOnSparkTestUtils):
def test_release_sessions(self):
big_enough_query = "select * from range(1000000)"
query1 = self.connect.sql(big_enough_query).toLocalIterator()
diff --git a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
index fb25c448cef0a..e772c2139326f 100644
--- a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
@@ -18,6 +18,7 @@
from pyspark.sql.tests.pandas.test_pandas_transform_with_state import (
TransformWithStateInPandasTestsMixin,
+ TransformWithStateInPySparkTestsMixin,
)
from pyspark import SparkConf
from pyspark.testing.connectutils import ReusedConnectTestCase
@@ -48,6 +49,40 @@ def conf(cls):
return cfg
+ @unittest.skip("Flaky in spark connect on CI. Skip for now. See SPARK-51368 for details.")
+ def test_schema_evolution_scenarios(self):
+ pass
+
+
+class TransformWithStateInPySparkParityTests(
+ TransformWithStateInPySparkTestsMixin, ReusedConnectTestCase
+):
+ """
+ Spark connect parity tests for TransformWithStateInPySpark. Run every test case in
+ `TransformWithStateInPySparkTestsMixin` in spark connect mode.
+ """
+
+ @classmethod
+ def conf(cls):
+ # Due to multiple inheritance from the same level, we need to explicitly set configs in
+ # both TransformWithStateInPySparkTestsMixin and ReusedConnectTestCase here
+ cfg = SparkConf(loadDefaults=False)
+ for base in cls.__bases__:
+ if hasattr(base, "conf"):
+ parent_cfg = base.conf()
+ for k, v in parent_cfg.getAll():
+ cfg.set(k, v)
+
+ # Additionally remove configs that are not applicable to connect suites
+ if cfg._jconf is not None:
+ cfg._jconf.remove("spark.master")
+
+ return cfg
+
+ @unittest.skip("Flaky in spark connect on CI. Skip for now. See SPARK-51368 for details.")
+ def test_schema_evolution_scenarios(self):
+ pass
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state import * # noqa: F401,E501
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f0637056ab8f9..2aa383f39937e 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -26,7 +26,6 @@
from pyspark.util import is_remote_only
from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.sql import SparkSession as PySparkSession, Row
from pyspark.sql.types import (
StructType,
StructField,
@@ -38,16 +37,13 @@
Row,
)
from pyspark.testing.utils import eventually
-from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.testing.connectutils import (
should_test_connect,
- ReusedConnectTestCase,
+ connect_requirement_message,
+ ReusedMixedTestCase,
)
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-from pyspark.errors.exceptions.connect import (
- AnalysisException,
- SparkConnectException,
-)
+
if should_test_connect:
from pyspark.sql.connect.proto import Expression as ProtoExpression
@@ -56,23 +52,20 @@
from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF
+ from pyspark.errors.exceptions.connect import AnalysisException, SparkConnectException
-@unittest.skipIf(is_remote_only(), "Requires JVM access")
-class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
+@unittest.skipIf(
+ not should_test_connect or is_remote_only(),
+ connect_requirement_message or "Requires JVM access",
+)
+class SparkConnectSQLTestCase(ReusedMixedTestCase, PandasOnSparkTestUtils):
"""Parent test fixture class for all Spark Connect related
test cases."""
@classmethod
def setUpClass(cls):
super(SparkConnectSQLTestCase, cls).setUpClass()
- # Disable the shared namespace so pyspark.sql.functions, etc point the regular
- # PySpark libraries.
- os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
-
- cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session.
- cls.spark = PySparkSession._instantiatedSession
- assert cls.spark is not None
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
cls.testDataStr = [Row(key=str(i)) for i in range(100)]
@@ -94,9 +87,6 @@ def setUpClass(cls):
def tearDownClass(cls):
try:
cls.spark_connect_clean_up_test_data()
- # Stopping Spark Connect closes the session in JVM at the server.
- cls.spark = cls.connect
- del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
finally:
super(SparkConnectSQLTestCase, cls).tearDownClass()
diff --git a/python/pyspark/sql/tests/connect/test_connect_channel.py b/python/pyspark/sql/tests/connect/test_connect_channel.py
new file mode 100644
index 0000000000000..92b418f7bc0c4
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_connect_channel.py
@@ -0,0 +1,163 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import uuid
+
+from pyspark.errors import PySparkValueError
+from pyspark.testing.connectutils import (
+ should_test_connect,
+ connect_requirement_message,
+)
+
+if should_test_connect:
+ import grpc
+ from pyspark.sql.connect.client import DefaultChannelBuilder, ChannelBuilder
+ from pyspark.sql.connect.client.core import SparkConnectClient
+ from pyspark.errors.exceptions.connect import SparkConnectException
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class ChannelBuilderTests(unittest.TestCase):
+ def test_invalid_connection_strings(self):
+ invalid = [
+ "scc://host:12",
+ "http://host",
+ "sc:/host:1234/path",
+ "sc://host/path",
+ "sc://host/;parm1;param2",
+ ]
+ for i in invalid:
+ self.assertRaises(PySparkValueError, DefaultChannelBuilder, i)
+
+ def test_sensible_defaults(self):
+ chan = DefaultChannelBuilder("sc://host")
+ self.assertFalse(chan.secure, "Default URL is not secure")
+
+ chan = DefaultChannelBuilder("sc://host/;token=abcs")
+ self.assertTrue(chan.secure, "specifying a token must set the channel to secure")
+ self.assertRegex(
+ chan.userAgent, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$"
+ )
+ chan = DefaultChannelBuilder("sc://host/;use_ssl=abcs")
+ self.assertFalse(chan.secure, "Garbage in, false out")
+
+ def test_user_agent(self):
+ chan = DefaultChannelBuilder("sc://host/;user_agent=Agent123%20%2F3.4")
+ self.assertIn("Agent123 /3.4", chan.userAgent)
+
+ def test_user_agent_len(self):
+ user_agent = "x" * 2049
+ chan = DefaultChannelBuilder(f"sc://host/;user_agent={user_agent}")
+ with self.assertRaises(SparkConnectException) as err:
+ chan.userAgent
+ self.assertRegex(err.exception._message, "'user_agent' parameter should not exceed")
+
+ user_agent = "%C3%A4" * 341 # "%C3%A4" -> "ä"; (341 * 6 = 2046) < 2048
+ expected = "ä" * 341
+ chan = DefaultChannelBuilder(f"sc://host/;user_agent={user_agent}")
+ self.assertIn(expected, chan.userAgent)
+
+ def test_valid_channel_creation(self):
+ chan = DefaultChannelBuilder("sc://host").toChannel()
+ self.assertIsInstance(chan, grpc.Channel)
+
+ # Sets up a channel without tokens because ssl is not used.
+ chan = DefaultChannelBuilder("sc://host/;use_ssl=true;token=abc").toChannel()
+ self.assertIsInstance(chan, grpc.Channel)
+
+ chan = DefaultChannelBuilder("sc://host/;use_ssl=true").toChannel()
+ self.assertIsInstance(chan, grpc.Channel)
+
+ def test_channel_properties(self):
+ chan = DefaultChannelBuilder(
+ "sc://host/;use_ssl=true;token=abc;user_agent=foo;param1=120%2021"
+ )
+ self.assertEqual("host:15002", chan.endpoint)
+ self.assertIn("foo", chan.userAgent.split(" "))
+ self.assertEqual(True, chan.secure)
+ self.assertEqual("120 21", chan.get("param1"))
+
+ def test_metadata(self):
+ chan = DefaultChannelBuilder(
+ "sc://host/;use_ssl=true;token=abc;param1=120%2021;x-my-header=abcd"
+ )
+ md = chan.metadata()
+ self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md)
+
+ def test_metadata_with_session_id(self):
+ id = str(uuid.uuid4())
+ chan = DefaultChannelBuilder(f"sc://host/;session_id={id}")
+ self.assertEqual(id, chan.session_id)
+
+ chan = DefaultChannelBuilder(
+ f"sc://host/;session_id={id};user_agent=acbd;token=abcd;use_ssl=true"
+ )
+ md = chan.metadata()
+ for kv in md:
+ self.assertNotIn(
+ kv[0],
+ [
+ ChannelBuilder.PARAM_SESSION_ID,
+ ChannelBuilder.PARAM_TOKEN,
+ ChannelBuilder.PARAM_USER_ID,
+ ChannelBuilder.PARAM_USER_AGENT,
+ ChannelBuilder.PARAM_USE_SSL,
+ ],
+ "Metadata must not contain fixed params",
+ )
+
+ with self.assertRaises(ValueError) as ve:
+ chan = DefaultChannelBuilder("sc://host/;session_id=abcd")
+ SparkConnectClient(chan)
+ self.assertIn("Parameter value session_id must be a valid UUID format", str(ve.exception))
+
+ chan = DefaultChannelBuilder("sc://host/")
+ self.assertIsNone(chan.session_id)
+
+ def test_channel_options(self):
+ # SPARK-47694
+ chan = DefaultChannelBuilder(
+ "sc://host", [("grpc.max_send_message_length", 1860), ("test", "robert")]
+ )
+ options = chan._channel_options
+ self.assertEqual(
+ [k for k, _ in options].count("grpc.max_send_message_length"),
+ 1,
+ "only one occurrence for defaults",
+ )
+ self.assertEqual(
+ next(v for k, v in options if k == "grpc.max_send_message_length"),
+ 1860,
+ "overwrites defaults",
+ )
+ self.assertEqual(
+ next(v for k, v in options if k == "test"), "robert", "new values are picked up"
+ )
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.test_connect_channel import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_connect_creation.py b/python/pyspark/sql/tests/connect/test_connect_creation.py
index 163a1c17bfafa..3d67c33a58349 100644
--- a/python/pyspark/sql/tests/connect/test_connect_creation.py
+++ b/python/pyspark/sql/tests/connect/test_connect_creation.py
@@ -32,7 +32,7 @@
ArrayType,
Row,
)
-from pyspark.testing.sqlutils import MyObject, PythonOnlyUDT
+from pyspark.testing.objects import MyObject, PythonOnlyUDT
from pyspark.testing.connectutils import should_test_connect
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
@@ -219,11 +219,6 @@ def test_with_atom_type(self):
self.assert_eq(sdf.toPandas(), cdf.toPandas())
def test_with_none_and_nan(self):
- # TODO(SPARK-51286): Fix test_with_none_and_nan to to pass with Arrow enabled
- with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
- self.check_with_none_and_nan()
-
- def check_with_none_and_nan(self):
# SPARK-41855: make createDataFrame support None and NaN
# SPARK-41814: test with eqNullSafe
data1 = [Row(id=1, value=float("NaN")), Row(id=2, value=42.0), Row(id=3, value=None)]
diff --git a/python/pyspark/sql/tests/connect/test_connect_error.py b/python/pyspark/sql/tests/connect/test_connect_error.py
index 01047741f6740..47963e89471c2 100644
--- a/python/pyspark/sql/tests/connect/test_connect_error.py
+++ b/python/pyspark/sql/tests/connect/test_connect_error.py
@@ -20,44 +20,43 @@
from pyspark.errors import PySparkAttributeError
from pyspark.errors.exceptions.base import SessionNotSameException
from pyspark.sql.types import Row
-from pyspark.testing.connectutils import should_test_connect
+from pyspark.sql import functions as F
from pyspark.errors import PySparkTypeError
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.util import is_remote_only
-if should_test_connect:
- from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
- from pyspark.sql.connect import functions as CF
- from pyspark.sql.connect.column import Column
- from pyspark.errors.exceptions.connect import AnalysisException
-
-class SparkConnectErrorTests(SparkConnectSQLTestCase):
+class SparkConnectErrorTests(ReusedConnectTestCase):
def test_recursion_handling_for_plan_logging(self):
"""SPARK-45852 - Test that we can handle recursion in plan logging."""
- cdf = self.connect.range(1)
+ cdf = self.spark.range(1)
for x in range(400):
- cdf = cdf.withColumn(f"col_{x}", CF.lit(x))
+ cdf = cdf.withColumn(f"col_{x}", F.lit(x))
# Calling schema will trigger logging the message that will in turn trigger the message
# conversion into protobuf that will then trigger the recursion error.
self.assertIsNotNone(cdf.schema)
- result = self.connect._client._proto_to_string(cdf._plan.to_proto(self.connect._client))
+ result = self.spark._client._proto_to_string(cdf._plan.to_proto(self.spark._client))
self.assertIn("recursion", result)
def test_error_handling(self):
+ from pyspark.errors.exceptions.connect import AnalysisException
+
# SPARK-41533 Proper error handling for Spark Connect
- df = self.connect.range(10).select("id2")
+ df = self.spark.range(10).select("id2")
with self.assertRaises(AnalysisException):
df.collect()
def test_invalid_column(self):
+ from pyspark.errors.exceptions.connect import AnalysisException
+
# SPARK-41812: fail df1.select(df2.col)
data1 = [Row(a=1, b=2, c=3)]
- cdf1 = self.connect.createDataFrame(data1)
+ cdf1 = self.spark.createDataFrame(data1)
data2 = [Row(a=2, b=0)]
- cdf2 = self.connect.createDataFrame(data2)
+ cdf2 = self.spark.createDataFrame(data2)
with self.assertRaises(AnalysisException):
cdf1.select(cdf2.a).schema
@@ -81,11 +80,13 @@ def test_invalid_column(self):
cdf1.select(cdf2.a).schema
def test_invalid_star(self):
+ from pyspark.errors.exceptions.connect import AnalysisException
+
data1 = [Row(a=1, b=2, c=3)]
- cdf1 = self.connect.createDataFrame(data1)
+ cdf1 = self.spark.createDataFrame(data1)
data2 = [Row(a=2, b=0)]
- cdf2 = self.connect.createDataFrame(data2)
+ cdf2 = self.spark.createDataFrame(data2)
# Can find the target plan node, but fail to resolve with it
with self.assertRaisesRegex(
@@ -101,7 +102,7 @@ def test_invalid_star(self):
"CANNOT_RESOLVE_DATAFRAME_COLUMN",
):
# column 'a has been replaced
- cdf3 = cdf1.withColumn("a", CF.lit(0))
+ cdf3 = cdf1.withColumn("a", F.lit(0))
cdf3.select(cdf1["*"]).schema
# Can not find the target plan node by plan id
@@ -119,15 +120,24 @@ def test_invalid_star(self):
cdf1.join(cdf1).select(cdf1["*"]).schema
def test_deduplicate_within_watermark_in_batch(self):
- df = self.connect.read.table(self.tbl_name)
- with self.assertRaisesRegex(
- AnalysisException,
- "dropDuplicatesWithinWatermark is not supported with batch DataFrames/DataSets",
- ):
- df.dropDuplicatesWithinWatermark().toPandas()
+ from pyspark.errors.exceptions.connect import AnalysisException
+
+ table_name = "tmp_table_for_test_deduplicate_within_watermark_in_batch"
+ with self.table(table_name):
+ self.spark.createDataFrame(
+ [Row(key=i, value=str(i)) for i in range(100)]
+ ).write.saveAsTable(table_name)
+
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "dropDuplicatesWithinWatermark is not supported with batch DataFrames/DataSets",
+ ):
+ self.spark.read.table(table_name).dropDuplicatesWithinWatermark().toPandas()
def test_different_spark_session_join_or_union(self):
- df = self.connect.range(10).limit(3)
+ from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
+
+ df = self.spark.range(10).limit(3)
spark2 = RemoteSparkSession(connection="sc://localhost")
df2 = spark2.range(10).limit(3)
@@ -156,9 +166,10 @@ def test_different_spark_session_join_or_union(self):
messageParameters={},
)
+ @unittest.skipIf(is_remote_only(), "Disabled for remote only")
def test_unsupported_functions(self):
# SPARK-41225: Disable unsupported functions.
- df = self.connect.read.table(self.tbl_name)
+ df = self.spark.range(10)
with self.assertRaises(NotImplementedError):
df.toJSON()
with self.assertRaises(NotImplementedError):
@@ -167,7 +178,7 @@ def test_unsupported_functions(self):
def test_unsupported_jvm_attribute(self):
# Unsupported jvm attributes for Spark session.
unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"]
- spark_session = self.connect
+ spark_session = self.spark
for attr in unsupported_attrs:
with self.assertRaises(PySparkAttributeError) as pe:
getattr(spark_session, attr)
@@ -180,7 +191,7 @@ def test_unsupported_jvm_attribute(self):
# Unsupported jvm attributes for DataFrame.
unsupported_attrs = ["_jseq", "_jdf", "_jmap", "_jcols"]
- cdf = self.connect.range(10)
+ cdf = self.spark.range(10)
for attr in unsupported_attrs:
with self.assertRaises(PySparkAttributeError) as pe:
getattr(cdf, attr)
@@ -212,12 +223,14 @@ def test_unsupported_jvm_attribute(self):
)
def test_column_cannot_be_constructed_from_string(self):
+ from pyspark.sql.connect.column import Column
+
with self.assertRaises(TypeError):
Column("col")
def test_select_none(self):
with self.assertRaises(PySparkTypeError) as e1:
- self.connect.range(1).select(None)
+ self.spark.range(1).select(None)
self.check_error(
exception=e1.exception,
@@ -228,7 +241,7 @@ def test_select_none(self):
def test_ym_interval_in_collect(self):
# YearMonthIntervalType is not supported in python side arrow conversion
with self.assertRaises(PySparkTypeError):
- self.connect.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval").first()
+ self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS interval").first()
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index d1e2558305291..20ce6b88e390a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -14,13 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import os
import unittest
from inspect import getmembers, isfunction
from pyspark.util import is_remote_only
from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.sql import SparkSession as PySparkSession
from pyspark.sql.types import (
_drop_metadata,
StringType,
@@ -31,9 +29,7 @@
)
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
-from pyspark.testing.sqlutils import SQLTestUtils
-from pyspark.errors.exceptions.connect import AnalysisException, SparkConnectException
+from pyspark.testing.connectutils import ReusedMixedTestCase, should_test_connect
if should_test_connect:
from pyspark.sql.connect.column import Column
@@ -41,47 +37,14 @@
from pyspark.sql.window import Window as SW
from pyspark.sql.connect import functions as CF
from pyspark.sql.connect.window import Window as CW
+ from pyspark.errors.exceptions.connect import AnalysisException, SparkConnectException
@unittest.skipIf(is_remote_only(), "Requires JVM access")
-class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils):
+class SparkConnectFunctionTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
"""These test cases exercise the interface to the proto plan
generation but do not call Spark."""
- @classmethod
- def setUpClass(cls):
- super(SparkConnectFunctionTests, cls).setUpClass()
- # Disable the shared namespace so pyspark.sql.functions, etc point the regular
- # PySpark libraries.
- os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
- cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session.
- cls.spark = PySparkSession._instantiatedSession
- assert cls.spark is not None
-
- @classmethod
- def tearDownClass(cls):
- cls.spark = cls.connect # Stopping Spark Connect closes the session in JVM at the server.
- super(SparkConnectFunctionTests, cls).setUpClass()
- del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
-
- def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
- from pyspark.sql.classic.dataframe import DataFrame as SDF
- from pyspark.sql.connect.dataframe import DataFrame as CDF
-
- assert isinstance(df1, (SDF, CDF))
- if isinstance(df1, SDF):
- str1 = df1._jdf.showString(n, truncate, False)
- else:
- str1 = df1._show_string(n, truncate, False)
-
- assert isinstance(df2, (SDF, CDF))
- if isinstance(df2, SDF):
- str2 = df2._jdf.showString(n, truncate, False)
- else:
- str2 = df2._show_string(n, truncate, False)
-
- self.assertEqual(str1, str2)
-
def test_count_star(self):
# SPARK-42099: test count(*), count(col(*)) and count(expr(*))
data = [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")]
@@ -2609,7 +2572,6 @@ def test_non_deterministic_with_seed(self):
if __name__ == "__main__":
- import os
from pyspark.sql.tests.connect.test_connect_function import * # noqa: F401
try:
diff --git a/python/pyspark/sql/tests/connect/test_connect_readwriter.py b/python/pyspark/sql/tests/connect/test_connect_readwriter.py
index 06266b86de3ff..dc82d93f9581e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_connect_readwriter.py
@@ -30,7 +30,7 @@
MapType,
Row,
)
-from pyspark.testing.sqlutils import (
+from pyspark.testing.objects import (
PythonOnlyUDT,
ExamplePoint,
PythonOnlyPoint,
diff --git a/python/pyspark/sql/tests/connect/test_connect_retry.py b/python/pyspark/sql/tests/connect/test_connect_retry.py
new file mode 100644
index 0000000000000..f51e062479284
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_connect_retry.py
@@ -0,0 +1,179 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from collections import defaultdict
+
+from pyspark.errors import RetriesExceeded
+from pyspark.testing.connectutils import (
+ should_test_connect,
+ connect_requirement_message,
+)
+
+if should_test_connect:
+ import grpc
+ from pyspark.sql.connect.client.core import Retrying
+ from pyspark.sql.connect.client.retries import RetryPolicy
+
+
+if should_test_connect:
+
+ class TestError(grpc.RpcError, Exception):
+ def __init__(self, code: grpc.StatusCode):
+ self._code = code
+
+ def code(self):
+ return self._code
+
+ class TestPolicy(RetryPolicy):
+ # Put a small value for initial backoff so that tests don't spend
+ # time waiting
+ def __init__(self, initial_backoff=10, **kwargs):
+ super().__init__(initial_backoff=initial_backoff, **kwargs)
+
+ def can_retry(self, exception: BaseException):
+ return isinstance(exception, TestError)
+
+ class TestPolicySpecificError(TestPolicy):
+ def __init__(self, specific_code: grpc.StatusCode, **kwargs):
+ super().__init__(**kwargs)
+ self.specific_code = specific_code
+
+ def can_retry(self, exception: BaseException):
+ return exception.code() == self.specific_code
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class RetryTests(unittest.TestCase):
+ def setUp(self) -> None:
+ self.call_wrap = defaultdict(int)
+
+ def stub(self, retries, code):
+ self.call_wrap["attempts"] += 1
+ if self.call_wrap["attempts"] < retries:
+ self.call_wrap["raised"] += 1
+ raise TestError(code)
+
+ def test_simple(self):
+ # Check that max_retries 1 is only one retry so two attempts.
+ for attempt in Retrying(TestPolicy(max_retries=1)):
+ with attempt:
+ self.stub(2, grpc.StatusCode.INTERNAL)
+
+ self.assertEqual(2, self.call_wrap["attempts"])
+ self.assertEqual(1, self.call_wrap["raised"])
+
+ def test_below_limit(self):
+ # Check that if we have less than 4 retries all is ok.
+ for attempt in Retrying(TestPolicy(max_retries=4)):
+ with attempt:
+ self.stub(2, grpc.StatusCode.INTERNAL)
+
+ self.assertLess(self.call_wrap["attempts"], 4)
+ self.assertEqual(self.call_wrap["raised"], 1)
+
+ def test_exceed_retries(self):
+ # Exceed the retries.
+ with self.assertRaises(RetriesExceeded):
+ for attempt in Retrying(TestPolicy(max_retries=2)):
+ with attempt:
+ self.stub(5, grpc.StatusCode.INTERNAL)
+
+ self.assertLess(self.call_wrap["attempts"], 5)
+ self.assertEqual(self.call_wrap["raised"], 3)
+
+ def test_throw_not_retriable_error(self):
+ with self.assertRaises(ValueError):
+ for attempt in Retrying(TestPolicy(max_retries=2)):
+ with attempt:
+ raise ValueError
+
+ def test_specific_exception(self):
+ # Check that only specific exceptions are retried.
+ # Check that if we have less than 4 retries all is ok.
+ policy = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.UNAVAILABLE)
+
+ for attempt in Retrying(policy):
+ with attempt:
+ self.stub(2, grpc.StatusCode.UNAVAILABLE)
+
+ self.assertLess(self.call_wrap["attempts"], 4)
+ self.assertEqual(self.call_wrap["raised"], 1)
+
+ def test_specific_exception_exceed_retries(self):
+ # Exceed the retries.
+ policy = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE)
+ with self.assertRaises(RetriesExceeded):
+ for attempt in Retrying(policy):
+ with attempt:
+ self.stub(5, grpc.StatusCode.UNAVAILABLE)
+
+ self.assertLess(self.call_wrap["attempts"], 4)
+ self.assertEqual(self.call_wrap["raised"], 3)
+
+ def test_rejected_by_policy(self):
+ # Test that another error is always thrown.
+ policy = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.UNAVAILABLE)
+
+ with self.assertRaises(TestError):
+ for attempt in Retrying(policy):
+ with attempt:
+ self.stub(5, grpc.StatusCode.INTERNAL)
+
+ self.assertEqual(self.call_wrap["attempts"], 1)
+ self.assertEqual(self.call_wrap["raised"], 1)
+
+ def test_multiple_policies(self):
+ policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE)
+ policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL)
+
+ # Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors
+
+ error_supply = iter([grpc.StatusCode.UNAVAILABLE] * 2 + [grpc.StatusCode.INTERNAL] * 4)
+
+ for attempt in Retrying([policy1, policy2]):
+ with attempt:
+ error = next(error_supply, None)
+ if error:
+ raise TestError(error)
+
+ self.assertEqual(next(error_supply, None), None)
+
+ def test_multiple_policies_exceed(self):
+ policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.INTERNAL)
+ policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL)
+
+ with self.assertRaises(RetriesExceeded):
+ for attempt in Retrying([policy1, policy2]):
+ with attempt:
+ self.stub(10, grpc.StatusCode.INTERNAL)
+
+ self.assertEqual(self.call_wrap["attempts"], 7)
+ self.assertEqual(self.call_wrap["raised"], 7)
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.test_connect_retry import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+
+ unittest.main(testRunner=testRunner, verbosity=2)
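Beyond the unit tests above, a hypothetical usage sketch of the same Retrying/RetryPolicy machinery (AlwaysRetryPolicy and flaky are made up for illustration; the constructor kwargs mirror those used by TestPolicy):

    from pyspark.sql.connect.client.core import Retrying
    from pyspark.sql.connect.client.retries import RetryPolicy

    class AlwaysRetryPolicy(RetryPolicy):
        # Small backoff so a demo does not spend time waiting; retry everything.
        def __init__(self):
            super().__init__(initial_backoff=10, max_retries=3)

        def can_retry(self, exception: BaseException) -> bool:
            return True

    state = {"calls": 0}

    def flaky():
        state["calls"] += 1
        if state["calls"] < 3:
            raise RuntimeError("transient failure")
        return "ok"

    # Each attempt is a context manager; exceptions approved by can_retry trigger
    # a backoff and another attempt, up to max_retries, as in the tests above.
    for attempt in Retrying(AlwaysRetryPolicy()):
        with attempt:
            result = flaky()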
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py
index 1fd59609d450c..1857796ac9aa0 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -18,28 +18,21 @@
import os
import unittest
import uuid
-from collections import defaultdict
from pyspark.util import is_remote_only
-from pyspark.errors import (
- PySparkException,
- PySparkValueError,
- RetriesExceeded,
-)
+from pyspark.errors import PySparkException
from pyspark.sql import SparkSession as PySparkSession
-
from pyspark.testing.connectutils import (
should_test_connect,
ReusedConnectTestCase,
connect_requirement_message,
)
+from pyspark.testing.utils import timeout
if should_test_connect:
import grpc
from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
- from pyspark.sql.connect.client import DefaultChannelBuilder, ChannelBuilder
- from pyspark.sql.connect.client.core import Retrying, SparkConnectClient
- from pyspark.sql.connect.client.retries import RetryPolicy
+ from pyspark.sql.connect.client import ChannelBuilder
from pyspark.errors.exceptions.connect import (
AnalysisException,
SparkConnectException,
@@ -61,6 +54,7 @@ def setUp(self) -> None:
def tearDown(self):
self.spark.stop()
+ @timeout(10)
def test_progress_handler(self):
handler_called = []
@@ -85,6 +79,7 @@ def handler(**kwargs):
def _check_no_active_session_error(self, e: PySparkException):
self.check_error(exception=e, errorClass="NO_ACTIVE_SESSION", messageParameters=dict())
+ @timeout(10)
def test_stop_session(self):
df = self.spark.sql("select 1 as a, 2 as b")
catalog = self.spark.catalog
@@ -232,6 +227,7 @@ def test_get_message_parameters_without_enriched_error(self):
self.assertIsNotNone(exception)
self.assertEqual(exception.getMessageParameters(), {"objectName": "`a`"})
+ @timeout(10)
def test_custom_channel_builder(self):
# Access self.spark's DefaultChannelBuilder to reuse same endpoint
endpoint = self.spark._client._builder.endpoint
@@ -329,261 +325,6 @@ def test_config(self):
self.assertEqual(self.spark.conf.get("integer"), "1")
-if should_test_connect:
-
- class TestError(grpc.RpcError, Exception):
- def __init__(self, code: grpc.StatusCode):
- self._code = code
-
- def code(self):
- return self._code
-
- class TestPolicy(RetryPolicy):
- # Put a small value for initial backoff so that tests don't spend
- # Time waiting
- def __init__(self, initial_backoff=10, **kwargs):
- super().__init__(initial_backoff=initial_backoff, **kwargs)
-
- def can_retry(self, exception: BaseException):
- return isinstance(exception, TestError)
-
- class TestPolicySpecificError(TestPolicy):
- def __init__(self, specific_code: grpc.StatusCode, **kwargs):
- super().__init__(**kwargs)
- self.specific_code = specific_code
-
- def can_retry(self, exception: BaseException):
- return exception.code() == self.specific_code
-
-
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class RetryTests(unittest.TestCase):
- def setUp(self) -> None:
- self.call_wrap = defaultdict(int)
-
- def stub(self, retries, code):
- self.call_wrap["attempts"] += 1
- if self.call_wrap["attempts"] < retries:
- self.call_wrap["raised"] += 1
- raise TestError(code)
-
- def test_simple(self):
- # Check that max_retries 1 is only one retry so two attempts.
- for attempt in Retrying(TestPolicy(max_retries=1)):
- with attempt:
- self.stub(2, grpc.StatusCode.INTERNAL)
-
- self.assertEqual(2, self.call_wrap["attempts"])
- self.assertEqual(1, self.call_wrap["raised"])
-
- def test_below_limit(self):
- # Check that if we have less than 4 retries all is ok.
- for attempt in Retrying(TestPolicy(max_retries=4)):
- with attempt:
- self.stub(2, grpc.StatusCode.INTERNAL)
-
- self.assertLess(self.call_wrap["attempts"], 4)
- self.assertEqual(self.call_wrap["raised"], 1)
-
- def test_exceed_retries(self):
- # Exceed the retries.
- with self.assertRaises(RetriesExceeded):
- for attempt in Retrying(TestPolicy(max_retries=2)):
- with attempt:
- self.stub(5, grpc.StatusCode.INTERNAL)
-
- self.assertLess(self.call_wrap["attempts"], 5)
- self.assertEqual(self.call_wrap["raised"], 3)
-
- def test_throw_not_retriable_error(self):
- with self.assertRaises(ValueError):
- for attempt in Retrying(TestPolicy(max_retries=2)):
- with attempt:
- raise ValueError
-
- def test_specific_exception(self):
- # Check that only specific exceptions are retried.
- # Check that if we have less than 4 retries all is ok.
- policy = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.UNAVAILABLE)
-
- for attempt in Retrying(policy):
- with attempt:
- self.stub(2, grpc.StatusCode.UNAVAILABLE)
-
- self.assertLess(self.call_wrap["attempts"], 4)
- self.assertEqual(self.call_wrap["raised"], 1)
-
- def test_specific_exception_exceed_retries(self):
- # Exceed the retries.
- policy = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE)
- with self.assertRaises(RetriesExceeded):
- for attempt in Retrying(policy):
- with attempt:
- self.stub(5, grpc.StatusCode.UNAVAILABLE)
-
- self.assertLess(self.call_wrap["attempts"], 4)
- self.assertEqual(self.call_wrap["raised"], 3)
-
- def test_rejected_by_policy(self):
- # Test that another error is always thrown.
- policy = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.UNAVAILABLE)
-
- with self.assertRaises(TestError):
- for attempt in Retrying(policy):
- with attempt:
- self.stub(5, grpc.StatusCode.INTERNAL)
-
- self.assertEqual(self.call_wrap["attempts"], 1)
- self.assertEqual(self.call_wrap["raised"], 1)
-
- def test_multiple_policies(self):
- policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE)
- policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL)
-
- # Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors
-
- error_suply = iter([grpc.StatusCode.UNAVAILABLE] * 2 + [grpc.StatusCode.INTERNAL] * 4)
-
- for attempt in Retrying([policy1, policy2]):
- with attempt:
- error = next(error_suply, None)
- if error:
- raise TestError(error)
-
- self.assertEqual(next(error_suply, None), None)
-
- def test_multiple_policies_exceed(self):
- policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.INTERNAL)
- policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL)
-
- with self.assertRaises(RetriesExceeded):
- for attempt in Retrying([policy1, policy2]):
- with attempt:
- self.stub(10, grpc.StatusCode.INTERNAL)
-
- self.assertEqual(self.call_wrap["attempts"], 7)
- self.assertEqual(self.call_wrap["raised"], 7)
-
-
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class ChannelBuilderTests(unittest.TestCase):
- def test_invalid_connection_strings(self):
- invalid = [
- "scc://host:12",
- "http://host",
- "sc:/host:1234/path",
- "sc://host/path",
- "sc://host/;parm1;param2",
- ]
- for i in invalid:
- self.assertRaises(PySparkValueError, DefaultChannelBuilder, i)
-
- def test_sensible_defaults(self):
- chan = DefaultChannelBuilder("sc://host")
- self.assertFalse(chan.secure, "Default URL is not secure")
-
- chan = DefaultChannelBuilder("sc://host/;token=abcs")
- self.assertTrue(chan.secure, "specifying a token must set the channel to secure")
- self.assertRegex(
- chan.userAgent, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$"
- )
- chan = DefaultChannelBuilder("sc://host/;use_ssl=abcs")
- self.assertFalse(chan.secure, "Garbage in, false out")
-
- def test_user_agent(self):
- chan = DefaultChannelBuilder("sc://host/;user_agent=Agent123%20%2F3.4")
- self.assertIn("Agent123 /3.4", chan.userAgent)
-
- def test_user_agent_len(self):
- user_agent = "x" * 2049
- chan = DefaultChannelBuilder(f"sc://host/;user_agent={user_agent}")
- with self.assertRaises(SparkConnectException) as err:
- chan.userAgent
- self.assertRegex(err.exception._message, "'user_agent' parameter should not exceed")
-
- user_agent = "%C3%A4" * 341 # "%C3%A4" -> "ä"; (341 * 6 = 2046) < 2048
- expected = "ä" * 341
- chan = DefaultChannelBuilder(f"sc://host/;user_agent={user_agent}")
- self.assertIn(expected, chan.userAgent)
-
- def test_valid_channel_creation(self):
- chan = DefaultChannelBuilder("sc://host").toChannel()
- self.assertIsInstance(chan, grpc.Channel)
-
- # Sets up a channel without tokens because ssl is not used.
- chan = DefaultChannelBuilder("sc://host/;use_ssl=true;token=abc").toChannel()
- self.assertIsInstance(chan, grpc.Channel)
-
- chan = DefaultChannelBuilder("sc://host/;use_ssl=true").toChannel()
- self.assertIsInstance(chan, grpc.Channel)
-
- def test_channel_properties(self):
- chan = DefaultChannelBuilder(
- "sc://host/;use_ssl=true;token=abc;user_agent=foo;param1=120%2021"
- )
- self.assertEqual("host:15002", chan.endpoint)
- self.assertIn("foo", chan.userAgent.split(" "))
- self.assertEqual(True, chan.secure)
- self.assertEqual("120 21", chan.get("param1"))
-
- def test_metadata(self):
- chan = DefaultChannelBuilder(
- "sc://host/;use_ssl=true;token=abc;param1=120%2021;x-my-header=abcd"
- )
- md = chan.metadata()
- self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md)
-
- def test_metadata_with_session_id(self):
- id = str(uuid.uuid4())
- chan = DefaultChannelBuilder(f"sc://host/;session_id={id}")
- self.assertEqual(id, chan.session_id)
-
- chan = DefaultChannelBuilder(
- f"sc://host/;session_id={id};user_agent=acbd;token=abcd;use_ssl=true"
- )
- md = chan.metadata()
- for kv in md:
- self.assertNotIn(
- kv[0],
- [
- ChannelBuilder.PARAM_SESSION_ID,
- ChannelBuilder.PARAM_TOKEN,
- ChannelBuilder.PARAM_USER_ID,
- ChannelBuilder.PARAM_USER_AGENT,
- ChannelBuilder.PARAM_USE_SSL,
- ],
- "Metadata must not contain fixed params",
- )
-
- with self.assertRaises(ValueError) as ve:
- chan = DefaultChannelBuilder("sc://host/;session_id=abcd")
- SparkConnectClient(chan)
- self.assertIn("Parameter value session_id must be a valid UUID format", str(ve.exception))
-
- chan = DefaultChannelBuilder("sc://host/")
- self.assertIsNone(chan.session_id)
-
- def test_channel_options(self):
- # SPARK-47694
- chan = DefaultChannelBuilder(
- "sc://host", [("grpc.max_send_message_length", 1860), ("test", "robert")]
- )
- options = chan._channel_options
- self.assertEqual(
- [k for k, _ in options].count("grpc.max_send_message_length"),
- 1,
- "only one occurrence for defaults",
- )
- self.assertEqual(
- next(v for k, v in options if k == "grpc.max_send_message_length"),
- 1860,
- "overwrites defaults",
- )
- self.assertEqual(
- next(v for k, v in options if k == "test"), "robert", "new values are picked up"
- )
-
-
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_session import * # noqa: F401
diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py
index 40b6a072e9127..44ff85e2f9a9b 100644
--- a/python/pyspark/sql/tests/connect/test_df_debug.py
+++ b/python/pyspark/sql/tests/connect/test_df_debug.py
@@ -17,17 +17,13 @@
import unittest
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
-from pyspark.testing.connectutils import should_test_connect
+from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.utils import have_graphviz, graphviz_requirement_message
-if should_test_connect:
- from pyspark.sql.connect.dataframe import DataFrame
-
-class SparkConnectDataFrameDebug(SparkConnectSQLTestCase):
+class SparkConnectDataFrameDebug(ReusedConnectTestCase):
def test_df_debug_basics(self):
- df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ df = self.spark.range(100).repartition(10).groupBy("id").count()
x = df.collect() # noqa: F841
ei = df.executionInfo
@@ -35,12 +31,12 @@ def test_df_debug_basics(self):
self.assertIn(root, graph, "The root must be rooted in the graph")
def test_df_quey_execution_empty_before_execution(self):
- df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ df = self.spark.range(100).repartition(10).groupBy("id").count()
ei = df.executionInfo
self.assertIsNone(ei, "The query execution must be None before the action is executed")
def test_df_query_execution_with_writes(self):
- df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ df = self.spark.range(100).repartition(10).groupBy("id").count()
df.write.save("/tmp/test_df_query_execution_with_writes", format="json", mode="overwrite")
ei = df.executionInfo
self.assertIsNotNone(
@@ -48,18 +44,18 @@ def test_df_query_execution_with_writes(self):
)
def test_query_execution_text_format(self):
- df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ df = self.spark.range(100).repartition(10).groupBy("id").count()
df.collect()
self.assertIn("HashAggregate", df.executionInfo.metrics.toText())
# Different execution mode.
- df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ df = self.spark.range(100).repartition(10).groupBy("id").count()
df.toPandas()
self.assertIn("HashAggregate", df.executionInfo.metrics.toText())
@unittest.skipIf(not have_graphviz, graphviz_requirement_message)
def test_df_query_execution_metrics_to_dot(self):
- df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ df = self.spark.range(100).repartition(10).groupBy("id").count()
x = df.collect() # noqa: F841
ei = df.executionInfo
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index 0a77c5531082a..d23df4527133e 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -34,6 +34,10 @@ def test_function_parity(self):
def test_input_file_name_reset_for_rdd(self):
super().test_input_file_name_reset_for_rdd()
+ @unittest.skip("No need to test in Spark Connect.")
+ def test_wildcard_import(self):
+ super().test_wildcard_import()
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_functions import * # noqa: F401
diff --git a/python/pyspark/sql/tests/pandas/helper/__init__.py b/python/pyspark/sql/tests/pandas/helper/__init__.py
new file mode 100644
index 0000000000000..cce3acad34a49
--- /dev/null
+++ b/python/pyspark/sql/tests/pandas/helper/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
new file mode 100644
index 0000000000000..cc9f29609a5b0
--- /dev/null
+++ b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
@@ -0,0 +1,1615 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import abstractmethod
+import sys
+from typing import Iterator
+import unittest
+from pyspark.errors import PySparkRuntimeError
+from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
+from pyspark.sql.types import (
+ StringType,
+ StructType,
+ StructField,
+ Row,
+ IntegerType,
+ TimestampType,
+ LongType,
+ BooleanType,
+ FloatType,
+)
+from pyspark.testing.sqlutils import have_pandas
+
+if have_pandas:
+ import pandas as pd
+
+
+class StatefulProcessorFactory:
+ @abstractmethod
+ def pandas(self):
+ ...
+
+ @abstractmethod
+ def row(self):
+ ...
+
+
+# StatefulProcessor factory implementations
+
+
+class SimpleStatefulProcessorWithInitialStateFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasSimpleStatefulProcessorWithInitialState()
+
+ def row(self):
+ return RowSimpleStatefulProcessorWithInitialState()
+
+
+class StatefulProcessorWithInitialStateTimersFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasStatefulProcessorWithInitialStateTimers()
+
+ def row(self):
+ return RowStatefulProcessorWithInitialStateTimers()
+
+
+class StatefulProcessorWithListStateInitialStateFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasStatefulProcessorWithListStateInitialState()
+
+ def row(self):
+ return RowStatefulProcessorWithListStateInitialState()
+
+
+class EventTimeStatefulProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasEventTimeStatefulProcessor()
+
+ def row(self):
+ return RowEventTimeStatefulProcessor()
+
+
+class ProcTimeStatefulProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasProcTimeStatefulProcessor()
+
+ def row(self):
+ return RowProcTimeStatefulProcessor()
+
+
+class SimpleStatefulProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasSimpleStatefulProcessor()
+
+ def row(self):
+ return RowSimpleStatefulProcessor()
+
+
+class StatefulProcessorChainingOpsFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasStatefulProcessorChainingOps()
+
+ def row(self):
+ return RowStatefulProcessorChainingOps()
+
+
+class SimpleTTLStatefulProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasSimpleTTLStatefulProcessor()
+
+ def row(self):
+ return RowSimpleTTLStatefulProcessor()
+
+
+class TTLStatefulProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasTTLStatefulProcessor()
+
+ def row(self):
+ return RowTTLStatefulProcessor()
+
+
+class InvalidSimpleStatefulProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasInvalidSimpleStatefulProcessor()
+
+ def row(self):
+ return RowInvalidSimpleStatefulProcessor()
+
+
+class ListStateProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasListStateProcessor()
+
+ def row(self):
+ return RowListStateProcessor()
+
+
+class ListStateLargeListProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasListStateLargeListProcessor()
+
+ def row(self):
+ return RowListStateLargeListProcessor()
+
+
+class ListStateLargeTTLProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasListStateLargeTTLProcessor()
+
+ def row(self):
+ return RowListStateLargeTTLProcessor()
+
+
+class MapStateProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasMapStateProcessor()
+
+ def row(self):
+ return RowMapStateProcessor()
+
+
+class MapStateLargeTTLProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasMapStateLargeTTLProcessor()
+
+ def row(self):
+ return RowMapStateLargeTTLProcessor()
+
+
+class BasicProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasBasicProcessor()
+
+ def row(self):
+ return RowBasicProcessor()
+
+
+class BasicProcessorNotNullableFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasBasicProcessorNotNullable()
+
+ def row(self):
+ return RowBasicProcessorNotNullable()
+
+
+class AddFieldsProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasAddFieldsProcessor()
+
+ def row(self):
+ return RowAddFieldsProcessor()
+
+
+class RemoveFieldsProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasRemoveFieldsProcessor()
+
+ def row(self):
+ return RowRemoveFieldsProcessor()
+
+
+class ReorderedFieldsProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasReorderedFieldsProcessor()
+
+ def row(self):
+ return RowReorderedFieldsProcessor()
+
+
+class UpcastProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasUpcastProcessor()
+
+ def row(self):
+ return RowUpcastProcessor()
+
+
+class MinEventTimeStatefulProcessorFactory(StatefulProcessorFactory):
+ def pandas(self):
+ return PandasMinEventTimeStatefulProcessor()
+
+ def row(self):
+ return RowMinEventTimeStatefulProcessor()
+
+
+# StatefulProcessor implementations
+
+
+class PandasSimpleStatefulProcessorWithInitialState(StatefulProcessor):
+    # this dict is the same as the input initial state dataframe
+ dict = {("0",): 789, ("3",): 987}
+
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.value_state = handle.getValueState("value_state", state_schema)
+ self.handle = handle
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ exists = self.value_state.exists()
+ if exists:
+ value_row = self.value_state.get()
+ existing_value = value_row[0]
+ else:
+ existing_value = 0
+
+ accumulated_value = existing_value
+
+ for pdf in rows:
+ value = pdf["temperature"].astype(int).sum()
+ accumulated_value += value
+
+ self.value_state.update((accumulated_value,))
+
+ if len(key) > 1:
+ yield pd.DataFrame(
+ {"id1": (key[0],), "id2": (key[1],), "value": str(accumulated_value)}
+ )
+ else:
+ yield pd.DataFrame({"id": key, "value": str(accumulated_value)})
+
+ def handleInitialState(self, key, initialState, timerValues) -> None:
+ init_val = initialState.at[0, "initVal"]
+ self.value_state.update((init_val,))
+ if len(key) == 1:
+ assert self.dict[key] == init_val
+
+ def close(self) -> None:
+ pass
+
+
+class RowSimpleStatefulProcessorWithInitialState(StatefulProcessor):
+    # this dict is the same as the input initial state dataframe
+ dict = {("0",): 789, ("3",): 987}
+
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.value_state = handle.getValueState("value_state", state_schema)
+ self.handle = handle
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ exists = self.value_state.exists()
+ if exists:
+ value_row = self.value_state.get()
+ existing_value = value_row[0]
+ else:
+ existing_value = 0
+
+ accumulated_value = existing_value
+
+ for row in rows:
+ value = row.temperature
+ accumulated_value += value
+
+ self.value_state.update((accumulated_value,))
+
+ if len(key) > 1:
+ yield Row(id1=key[0], id2=key[1], value=str(accumulated_value))
+ else:
+ yield Row(id=key[0], value=str(accumulated_value))
+
+ def handleInitialState(self, key, initialState, timerValues) -> None:
+ init_val = initialState.initVal
+ self.value_state.update((init_val,))
+ if len(key) == 1:
+ assert self.dict[key] == init_val
+
+ def close(self) -> None:
+ pass
+
+
+class PandasStatefulProcessorWithInitialStateTimers(PandasSimpleStatefulProcessorWithInitialState):
+ def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
+ self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs())
+ str_key = f"{str(key[0])}-expired"
+ yield pd.DataFrame({"id": (str_key,), "value": str(expiredTimerInfo.getExpiryTimeInMs())})
+
+ def handleInitialState(self, key, initialState, timerValues) -> None:
+ super().handleInitialState(key, initialState, timerValues)
+ self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() - 1)
+
+
+class RowStatefulProcessorWithInitialStateTimers(RowSimpleStatefulProcessorWithInitialState):
+ def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[Row]:
+ self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs())
+ str_key = f"{str(key[0])}-expired"
+ yield Row(id=str_key, value=str(expiredTimerInfo.getExpiryTimeInMs()))
+
+ def handleInitialState(self, key, initialState, timerValues) -> None:
+ super().handleInitialState(key, initialState, timerValues)
+ self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() - 1)
+
+
+class PandasStatefulProcessorWithListStateInitialState(
+ PandasSimpleStatefulProcessorWithInitialState
+):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ super().init(handle)
+ list_ele_schema = StructType([StructField("value", IntegerType(), True)])
+ self.list_state = handle.getListState("list_state", list_ele_schema)
+
+ def handleInitialState(self, key, initialState, timerValues) -> None:
+ for val in initialState["initVal"].tolist():
+ self.list_state.appendValue((val,))
+
+
+class RowStatefulProcessorWithListStateInitialState(RowSimpleStatefulProcessorWithInitialState):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ super().init(handle)
+ list_ele_schema = StructType([StructField("value", IntegerType(), True)])
+ self.list_state = handle.getListState("list_state", list_ele_schema)
+
+ def handleInitialState(self, key, initialState, timerValues) -> None:
+ self.list_state.appendValue((initialState.initVal,))
+
+
+# A stateful processor that outputs the max event time it has seen. It registers a timer
+# for the current watermark and clears the max state when the timer expires.
+class PandasEventTimeStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", StringType(), True)])
+ self.handle = handle
+ self.max_state = handle.getValueState("max_state", state_schema)
+
+ def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
+ self.max_state.clear()
+ self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs())
+ str_key = f"{str(key[0])}-expired"
+ yield pd.DataFrame(
+ {"id": (str_key,), "timestamp": str(expiredTimerInfo.getExpiryTimeInMs())}
+ )
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ timestamp_list = []
+ for pdf in rows:
+            # int64 represents the timestamp in nanoseconds; convert back to seconds
+ timestamp_list.extend((pdf["eventTime"].astype("int64") // 10**9).tolist())
+
+ if self.max_state.exists():
+ cur_max = int(self.max_state.get()[0])
+ else:
+ cur_max = 0
+ max_event_time = str(max(cur_max, max(timestamp_list)))
+
+ self.max_state.update((max_event_time,))
+ self.handle.registerTimer(timerValues.getCurrentWatermarkInMs())
+
+ yield pd.DataFrame({"id": key, "timestamp": max_event_time})
+
+ def close(self) -> None:
+ pass
+
+
+# A stateful processor that outputs the max event time it has seen. It registers a timer
+# for the current watermark and clears the max state when the timer expires.
+class RowEventTimeStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", StringType(), True)])
+ self.handle = handle
+ self.max_state = handle.getValueState("max_state", state_schema)
+
+ def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[Row]:
+ self.max_state.clear()
+ self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs())
+ str_key = f"{str(key[0])}-expired"
+ yield Row(id=str_key, timestamp=str(expiredTimerInfo.getExpiryTimeInMs()))
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ timestamp_list = []
+ for row in rows:
+            # the timestamp is in microseconds; convert back to seconds
+ timestamp_list.append(int(row.eventTime.timestamp()))
+
+ if self.max_state.exists():
+ cur_max = int(self.max_state.get()[0])
+ else:
+ cur_max = 0
+ max_event_time = str(max(cur_max, max(timestamp_list)))
+
+ self.max_state.update((max_event_time,))
+ self.handle.registerTimer(timerValues.getCurrentWatermarkInMs())
+
+ yield Row(id=key[0], timestamp=max_event_time)
+
+ def close(self) -> None:
+ pass
+
+
+# A stateful processor that outputs the accumulated count of input rows; it registers a
+# processing-time timer and clears the counter when the timer expires.
+class PandasProcTimeStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", StringType(), True)])
+ self.handle = handle
+ self.count_state = handle.getValueState("count_state", state_schema)
+
+ def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
+        # reset the count state each time the timer expires
+ timer_list_1 = [e for e in self.handle.listTimers()]
+ timer_list_2 = []
+ idx = 0
+ for e in self.handle.listTimers():
+ timer_list_2.append(e)
+            # check that multiple iterators over the same grouping key work
+ assert timer_list_2[idx] == timer_list_1[idx]
+ idx += 1
+
+ if len(timer_list_1) > 0:
+ assert len(timer_list_1) == 2
+ self.count_state.clear()
+ self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs())
+ yield pd.DataFrame(
+ {
+ "id": key,
+ "countAsString": str("-1"),
+ "timeValues": str(expiredTimerInfo.getExpiryTimeInMs()),
+ }
+ )
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ if not self.count_state.exists():
+ count = 0
+ else:
+ count = int(self.count_state.get()[0])
+
+ if key == ("0",):
+ self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 1)
+
+ rows_count = 0
+ for pdf in rows:
+ pdf_count = len(pdf)
+ rows_count += pdf_count
+
+ count = count + rows_count
+
+ self.count_state.update((str(count),))
+ timestamp = str(timerValues.getCurrentProcessingTimeInMs())
+
+ yield pd.DataFrame({"id": key, "countAsString": str(count), "timeValues": timestamp})
+
+ def close(self) -> None:
+ pass
+
+
+# A stateful processor that outputs the accumulated count of input rows; it registers a
+# processing-time timer and clears the counter when the timer expires.
+class RowProcTimeStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", StringType(), True)])
+ self.handle = handle
+ self.count_state = handle.getValueState("count_state", state_schema)
+
+ def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[Row]:
+        # reset the count state each time the timer expires
+ timer_list_1 = [e for e in self.handle.listTimers()]
+ timer_list_2 = []
+ idx = 0
+ for e in self.handle.listTimers():
+ timer_list_2.append(e)
+            # check that multiple iterators over the same grouping key work
+ assert timer_list_2[idx] == timer_list_1[idx]
+ idx += 1
+
+ if len(timer_list_1) > 0:
+ assert len(timer_list_1) == 2
+ self.count_state.clear()
+ self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs())
+ yield Row(
+ id=key[0], countAsString=str(-1), timeValues=str(expiredTimerInfo.getExpiryTimeInMs())
+ )
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ if not self.count_state.exists():
+ count = 0
+ else:
+ count = int(self.count_state.get()[0])
+
+ if key == ("0",):
+ self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 1)
+
+ rows_count = 0
+ for row in rows:
+ rows_count += 1
+
+ count = count + rows_count
+
+ self.count_state.update((str(count),))
+ timestamp = str(timerValues.getCurrentProcessingTimeInMs())
+
+ yield Row(id=key[0], countAsString=str(count), timeValues=timestamp)
+
+ def close(self) -> None:
+ pass
+
+
+class PandasSimpleStatefulProcessor(StatefulProcessor, unittest.TestCase):
+ dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}}
+ batch_id = 0
+
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ # Test both string type and struct type schemas
+ self.num_violations_state = handle.getValueState("numViolations", "value int")
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.temp_state = handle.getValueState("tempState", state_schema)
+ handle.deleteIfExists("tempState")
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"):
+ self.temp_state.exists()
+ new_violations = 0
+ count = 0
+ key_str = key[0]
+ exists = self.num_violations_state.exists()
+ if exists:
+ existing_violations_row = self.num_violations_state.get()
+ existing_violations = existing_violations_row[0]
+ assert existing_violations == self.dict[0][key_str]
+ self.batch_id = 1
+ else:
+ existing_violations = 0
+ for pdf in rows:
+ pdf_count = pdf.count()
+ count += pdf_count.get("temperature")
+ violations_pdf = pdf.loc[pdf["temperature"] > 100]
+ new_violations += violations_pdf.count().get("temperature")
+ updated_violations = new_violations + existing_violations
+ assert updated_violations == self.dict[self.batch_id][key_str]
+ self.num_violations_state.update((updated_violations,))
+ yield pd.DataFrame({"id": key, "countAsString": str(count)})
+
+ def close(self) -> None:
+ pass
+
+
+class RowSimpleStatefulProcessor(StatefulProcessor, unittest.TestCase):
+ dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}}
+ batch_id = 0
+
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ # Test both string type and struct type schemas
+ self.num_violations_state = handle.getValueState("numViolations", "value int")
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.temp_state = handle.getValueState("tempState", state_schema)
+ handle.deleteIfExists("tempState")
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"):
+ self.temp_state.exists()
+ new_violations = 0
+ count = 0
+ key_str = key[0]
+ exists = self.num_violations_state.exists()
+ if exists:
+ existing_violations_row = self.num_violations_state.get()
+ existing_violations = existing_violations_row[0]
+ assert existing_violations == self.dict[0][key_str]
+ self.batch_id = 1
+ else:
+ existing_violations = 0
+ for row in rows:
+ # temperature should be non-NA to be counted
+ temperature = row.temperature
+ if temperature is not None:
+ count += 1
+ if temperature > 100:
+ new_violations += 1
+ updated_violations = new_violations + existing_violations
+ assert updated_violations == self.dict[self.batch_id][key_str]
+ self.num_violations_state.update((updated_violations,))
+ yield Row(id=key[0], countAsString=str(count))
+
+ def close(self) -> None:
+ pass
+
+
+class PandasStatefulProcessorChainingOps(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ pass
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ for pdf in rows:
+ timestamp_list = pdf["eventTime"].tolist()
+ yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]})
+
+ def close(self) -> None:
+ pass
+
+
+class RowStatefulProcessorChainingOps(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ pass
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ timestamp_list = []
+ for row in rows:
+ timestamp_list.append(row.eventTime)
+ yield Row(id=key[0], outputTimestamp=timestamp_list[0])
+
+ def close(self) -> None:
+ pass
+
+
+# A stateful processor that inherits all behavior of SimpleStatefulProcessor except that it
+# uses TTL state with a large timeout.
+class PandasSimpleTTLStatefulProcessor(PandasSimpleStatefulProcessor, unittest.TestCase):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.num_violations_state = handle.getValueState("numViolations", state_schema, 30000)
+ self.temp_state = handle.getValueState("tempState", state_schema)
+ handle.deleteIfExists("tempState")
+
+
+# A stateful processor that inherits all behavior of SimpleStatefulProcessor except that it
+# uses TTL state with a large timeout.
+class RowSimpleTTLStatefulProcessor(RowSimpleStatefulProcessor, unittest.TestCase):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.num_violations_state = handle.getValueState("numViolations", state_schema, 30000)
+ self.temp_state = handle.getValueState("tempState", state_schema)
+ handle.deleteIfExists("tempState")
+
+
+class PandasTTLStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ user_key_schema = StructType([StructField("id", StringType(), True)])
+ self.ttl_count_state = handle.getValueState("ttl-state", state_schema, 10000)
+ self.count_state = handle.getValueState("state", state_schema)
+ self.ttl_list_state = handle.getListState("ttl-list-state", state_schema, 10000)
+ self.ttl_map_state = handle.getMapState(
+ "ttl-map-state", user_key_schema, state_schema, 10000
+ )
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ count = 0
+ ttl_count = 0
+ ttl_list_state_count = 0
+ ttl_map_state_count = 0
+ id = key[0]
+ if self.count_state.exists():
+ count = self.count_state.get()[0]
+ if self.ttl_count_state.exists():
+ ttl_count = self.ttl_count_state.get()[0]
+ if self.ttl_list_state.exists():
+ iter = self.ttl_list_state.get()
+ for s in iter:
+ ttl_list_state_count += s[0]
+ if self.ttl_map_state.exists():
+ ttl_map_state_count = self.ttl_map_state.getValue(key)[0]
+ for pdf in rows:
+ pdf_count = pdf.count().get("temperature")
+ count += pdf_count
+ ttl_count += pdf_count
+ ttl_list_state_count += pdf_count
+ ttl_map_state_count += pdf_count
+
+ self.count_state.update((count,))
+        # skip updating state for the 2nd batch so that the TTL state expires
+ if not (ttl_count == 2 and id == "0"):
+ self.ttl_count_state.update((ttl_count,))
+ self.ttl_list_state.put([(ttl_list_state_count,), (ttl_list_state_count,)])
+ self.ttl_map_state.updateValue(key, (ttl_map_state_count,))
+ yield pd.DataFrame(
+ {
+ "id": [
+ f"ttl-count-{id}",
+ f"count-{id}",
+ f"ttl-list-state-count-{id}",
+ f"ttl-map-state-count-{id}",
+ ],
+ "count": [ttl_count, count, ttl_list_state_count, ttl_map_state_count],
+ }
+ )
+
+
+class RowTTLStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ user_key_schema = StructType([StructField("id", StringType(), True)])
+ self.ttl_count_state = handle.getValueState("ttl-state", state_schema, 10000)
+ self.count_state = handle.getValueState("state", state_schema)
+ self.ttl_list_state = handle.getListState("ttl-list-state", state_schema, 10000)
+ self.ttl_map_state = handle.getMapState(
+ "ttl-map-state", user_key_schema, state_schema, 10000
+ )
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ count = 0
+ ttl_count = 0
+ ttl_list_state_count = 0
+ ttl_map_state_count = 0
+ id = key[0]
+ if self.count_state.exists():
+ count = self.count_state.get()[0]
+ if self.ttl_count_state.exists():
+ ttl_count = self.ttl_count_state.get()[0]
+ if self.ttl_list_state.exists():
+ iter = self.ttl_list_state.get()
+ for s in iter:
+ ttl_list_state_count += s[0]
+ if self.ttl_map_state.exists():
+ ttl_map_state_count = self.ttl_map_state.getValue(key)[0]
+ for row in rows:
+ if row.temperature is not None:
+ count += 1
+ ttl_count += 1
+ ttl_list_state_count += 1
+ ttl_map_state_count += 1
+
+ self.count_state.update((count,))
+        # skip updating state for the 2nd batch so that the TTL state expires
+ if not (ttl_count == 2 and id == "0"):
+ self.ttl_count_state.update((ttl_count,))
+ self.ttl_list_state.put([(ttl_list_state_count,), (ttl_list_state_count,)])
+ self.ttl_map_state.updateValue(key, (ttl_map_state_count,))
+
+ ret = [
+ Row(id=f"ttl-count-{id}", count=ttl_count),
+ Row(id=f"count-{id}", count=count),
+ Row(id=f"ttl-list-state-count-{id}", count=ttl_list_state_count),
+ Row(id=f"ttl-map-state-count-{id}", count=ttl_map_state_count),
+ ]
+ return iter(ret)
+
+
+class PandasInvalidSimpleStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.num_violations_state = handle.getValueState("numViolations", state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ count = 0
+ exists = self.num_violations_state.exists()
+ assert not exists
+ # try to get a state variable with no value
+ assert self.num_violations_state.get() is None
+ self.num_violations_state.clear()
+ yield pd.DataFrame({"id": key, "countAsString": str(count)})
+
+
+class RowInvalidSimpleStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.num_violations_state = handle.getValueState("numViolations", state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ count = 0
+ exists = self.num_violations_state.exists()
+ assert not exists
+ # try to get a state variable with no value
+ assert self.num_violations_state.get() is None
+ self.num_violations_state.clear()
+ yield Row(id=key[0], countAsString=str(count))
+
+
+class PandasListStateProcessor(StatefulProcessor):
+    # Dict to store the expected results. The key is the index of the element within the list
+    # state, and the value is the expected temperature value. Since we set maxRecordsPerBatch
+    # to 2, the list state values are transferred back in batches of 2 entries.
+ dict = {0: 120, 1: 20}
+
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("temperature", IntegerType(), True)])
+ timestamp_schema = StructType([StructField("time", TimestampType(), True)])
+ self.list_state1 = handle.getListState("listState1", state_schema)
+ self.list_state2 = handle.getListState("listState2", state_schema)
+ self.list_state_timestamp = handle.getListState("listStateTimestamp", timestamp_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ import datetime
+
+ count = 0
+ time_list = []
+ for pdf in rows:
+ list_state_rows = [(120,), (20,)]
+ self.list_state1.put(list_state_rows)
+ self.list_state2.put(list_state_rows)
+ self.list_state1.appendValue((111,))
+ self.list_state2.appendValue((222,))
+ self.list_state1.appendList(list_state_rows)
+ self.list_state2.appendList(list_state_rows)
+ pdf_count = pdf.count()
+ count += pdf_count.get("temperature")
+ current_processing_time = datetime.datetime.fromtimestamp(
+ timerValues.getCurrentProcessingTimeInMs() / 1000
+ )
+ stored_time = current_processing_time + datetime.timedelta(minutes=1)
+ time_list.append((stored_time,))
+ iter1 = self.list_state1.get()
+ iter2 = self.list_state2.get()
+        # Mix the iterators to test that we can resume each one from the correct point
+ assert next(iter1)[0] == self.dict[0]
+ assert next(iter2)[0] == self.dict[0]
+ assert next(iter1)[0] == self.dict[1]
+ assert next(iter2)[0] == self.dict[1]
+        # Get another iterator for list_state1 to verify that the 2 iterators (iter1 and iter3)
+        # don't interfere with each other.
+ iter3 = self.list_state1.get()
+ assert next(iter3)[0] == self.dict[0]
+ assert next(iter3)[0] == self.dict[1]
+ # the second arrow batch should contain the appended value 111 for list_state1 and
+ # 222 for list_state2
+ assert next(iter1)[0] == 111
+ assert next(iter2)[0] == 222
+ assert next(iter3)[0] == 111
+ # since we put another 2 rows after 111/222, check them here
+ assert next(iter1)[0] == self.dict[0]
+ assert next(iter2)[0] == self.dict[0]
+ assert next(iter3)[0] == self.dict[0]
+ assert next(iter1)[0] == self.dict[1]
+ assert next(iter2)[0] == self.dict[1]
+ assert next(iter3)[0] == self.dict[1]
+ if time_list:
+            # Validate that the timestamp type works properly with Arrow transmission
+ self.list_state_timestamp.put(time_list)
+ yield pd.DataFrame({"id": key, "countAsString": str(count)})
+
+ def close(self) -> None:
+ pass
+
+
+class RowListStateProcessor(StatefulProcessor):
+    # Dict to store the expected results. The key is the index of the element within the list
+    # state, and the value is the expected temperature value. Since we set maxRecordsPerBatch
+    # to 2, the list state values are transferred back in batches of 2 entries.
+ dict = {0: 120, 1: 20}
+
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("temperature", IntegerType(), True)])
+ timestamp_schema = StructType([StructField("time", TimestampType(), True)])
+ self.list_state1 = handle.getListState("listState1", state_schema)
+ self.list_state2 = handle.getListState("listState2", state_schema)
+ self.list_state_timestamp = handle.getListState("listStateTimestamp", timestamp_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ import datetime
+
+ count = 0
+ time_list = []
+ for row in rows:
+ list_state_rows = [(120,), (20,)]
+ self.list_state1.put(list_state_rows)
+ self.list_state2.put(list_state_rows)
+ self.list_state1.appendValue((111,))
+ self.list_state2.appendValue((222,))
+ self.list_state1.appendList(list_state_rows)
+ self.list_state2.appendList(list_state_rows)
+
+ if row.temperature is not None:
+ count += 1
+
+ current_processing_time = datetime.datetime.fromtimestamp(
+ timerValues.getCurrentProcessingTimeInMs() / 1000
+ )
+ stored_time = current_processing_time + datetime.timedelta(minutes=1)
+ time_list.append((stored_time,))
+ iter1 = self.list_state1.get()
+ iter2 = self.list_state2.get()
+        # Mix the iterators to test that we can resume each one from the correct point
+ assert next(iter1)[0] == self.dict[0]
+ assert next(iter2)[0] == self.dict[0]
+ assert next(iter1)[0] == self.dict[1]
+ assert next(iter2)[0] == self.dict[1]
+        # Get another iterator for list_state1 to verify that the 2 iterators (iter1 and iter3)
+        # don't interfere with each other.
+ iter3 = self.list_state1.get()
+ assert next(iter3)[0] == self.dict[0]
+ assert next(iter3)[0] == self.dict[1]
+ # the second arrow batch should contain the appended value 111 for list_state1 and
+ # 222 for list_state2
+ assert next(iter1)[0] == 111
+ assert next(iter2)[0] == 222
+ assert next(iter3)[0] == 111
+ # since we put another 2 rows after 111/222, check them here
+ assert next(iter1)[0] == self.dict[0]
+ assert next(iter2)[0] == self.dict[0]
+ assert next(iter3)[0] == self.dict[0]
+ assert next(iter1)[0] == self.dict[1]
+ assert next(iter2)[0] == self.dict[1]
+ assert next(iter3)[0] == self.dict[1]
+ if time_list:
+            # Validate that the timestamp type works properly with Arrow transmission
+ self.list_state_timestamp.put(time_list)
+ yield Row(id=key[0], countAsString=str(count))
+
+ def close(self) -> None:
+ pass
+
+
+class PandasListStateLargeListProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ list_state_schema = StructType([StructField("value", IntegerType(), True)])
+ value_state_schema = StructType([StructField("size", IntegerType(), True)])
+ self.list_state = handle.getListState("listState", list_state_schema)
+ self.list_size_state = handle.getValueState("listSizeState", value_state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ elements_iter = self.list_state.get()
+ elements = list(elements_iter)
+
+        # Use the magic number 100 to test both the inline proto case and the Arrow case.
+        # TODO(SPARK-51907): Let's update this to be either flexible or a more reasonable
+        # default value backed by various benchmarks.
+        # Put 90 elements per batch:
+        # 1st batch: read 0 elements, write 90 elements, then read back 90 elements
+        # (both reads and writes use inline proto)
+        # 2nd batch: read 90 elements, write 90 elements, then read back 180 elements
+        # (reads use both inline proto and Arrow, writes use Arrow)
+
+ if len(elements) == 0:
+ # should be the first batch
+ assert self.list_size_state.get() is None
+ new_elements = [(i,) for i in range(90)]
+ if key == ("0",):
+ self.list_state.put(new_elements)
+ else:
+ self.list_state.appendList(new_elements)
+ self.list_size_state.update((len(new_elements),))
+ else:
+ # check the elements
+ list_size = self.list_size_state.get()
+ assert list_size is not None
+ list_size = list_size[0]
+ assert list_size == len(
+ elements
+ ), f"list_size ({list_size}) != len(elements) ({len(elements)})"
+
+ expected_elements_in_state = [(i,) for i in range(list_size)]
+ assert elements == expected_elements_in_state
+
+ if key == ("0",):
+ # Use the operation `put`
+ new_elements = [(i,) for i in range(list_size + 90)]
+ self.list_state.put(new_elements)
+ final_size = len(new_elements)
+ self.list_size_state.update((final_size,))
+ else:
+ # Use the operation `appendList`
+ new_elements = [(i,) for i in range(list_size, list_size + 90)]
+ self.list_state.appendList(new_elements)
+ final_size = len(new_elements) + list_size
+ self.list_size_state.update((final_size,))
+
+ prev_elements = ",".join(map(lambda x: str(x[0]), elements))
+ updated_elements = ",".join(map(lambda x: str(x[0]), self.list_state.get()))
+
+ yield pd.DataFrame(
+ {"id": key, "prevElements": prev_elements, "updatedElements": updated_elements}
+ )
+
+
+class RowListStateLargeListProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ list_state_schema = StructType([StructField("value", IntegerType(), True)])
+ value_state_schema = StructType([StructField("size", IntegerType(), True)])
+ self.list_state = handle.getListState("listState", list_state_schema)
+ self.list_size_state = handle.getValueState("listSizeState", value_state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ elements_iter = self.list_state.get()
+
+ elements = list(elements_iter)
+
+        # Use the magic number 100 to test both the inline proto case and the Arrow case.
+        # TODO(SPARK-51907): Let's update this to be either flexible or a more reasonable
+        # default value backed by various benchmarks.
+        # Put 90 elements per batch:
+        # 1st batch: read 0 elements, write 90 elements, then read back 90 elements
+        # (both reads and writes use inline proto)
+        # 2nd batch: read 90 elements, write 90 elements, then read back 180 elements
+        # (reads use both inline proto and Arrow, writes use Arrow)
+
+ if len(elements) == 0:
+ # should be the first batch
+ assert self.list_size_state.get() is None
+ new_elements = [(i,) for i in range(90)]
+ if key == ("0",):
+ self.list_state.put(new_elements)
+ else:
+ self.list_state.appendList(new_elements)
+ self.list_size_state.update((len(new_elements),))
+ else:
+ # check the elements
+ list_size = self.list_size_state.get()
+ assert list_size is not None
+ list_size = list_size[0]
+ assert list_size == len(
+ elements
+ ), f"list_size ({list_size}) != len(elements) ({len(elements)})"
+
+ expected_elements_in_state = [(i,) for i in range(list_size)]
+ assert elements == expected_elements_in_state
+
+ if key == ("0",):
+ # Use the operation `put`
+ new_elements = [(i,) for i in range(list_size + 90)]
+ self.list_state.put(new_elements)
+ final_size = len(new_elements)
+ self.list_size_state.update((final_size,))
+ else:
+ # Use the operation `appendList`
+ new_elements = [(i,) for i in range(list_size, list_size + 90)]
+ self.list_state.appendList(new_elements)
+ final_size = len(new_elements) + list_size
+ self.list_size_state.update((final_size,))
+
+ prev_elements = ",".join(map(lambda x: str(x[0]), elements))
+ updated_elements = ",".join(map(lambda x: str(x[0]), self.list_state.get()))
+
+ yield Row(id=key[0], prevElements=prev_elements, updatedElements=updated_elements)
+
+
+class PandasListStateLargeTTLProcessor(PandasListStateProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("temperature", IntegerType(), True)])
+ timestamp_schema = StructType([StructField("time", TimestampType(), True)])
+ self.list_state1 = handle.getListState("listState1", state_schema, 30000)
+ self.list_state2 = handle.getListState("listState2", state_schema, 30000)
+ self.list_state_timestamp = handle.getListState("listStateTimestamp", timestamp_schema)
+
+
+class RowListStateLargeTTLProcessor(RowListStateProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("temperature", IntegerType(), True)])
+ timestamp_schema = StructType([StructField("time", TimestampType(), True)])
+ self.list_state1 = handle.getListState("listState1", state_schema, 30000)
+ self.list_state2 = handle.getListState("listState2", state_schema, 30000)
+ self.list_state_timestamp = handle.getListState("listStateTimestamp", timestamp_schema)
+
+
+class PandasMapStateProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle):
+ # Test string type schemas
+ self.map_state = handle.getMapState("mapState", "name string", "count int")
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ count = 0
+ key1 = ("key1",)
+ key2 = ("key2",)
+ for pdf in rows:
+ pdf_count = pdf.count()
+ count += pdf_count.get("temperature")
+ value1 = count
+ value2 = count
+ if self.map_state.exists():
+ if self.map_state.containsKey(key1):
+ value1 += self.map_state.getValue(key1)[0]
+ if self.map_state.containsKey(key2):
+ value2 += self.map_state.getValue(key2)[0]
+ self.map_state.updateValue(key1, (value1,))
+ self.map_state.updateValue(key2, (value2,))
+ key_iter = self.map_state.keys()
+ assert next(key_iter)[0] == "key1"
+ assert next(key_iter)[0] == "key2"
+ value_iter = self.map_state.values()
+ assert next(value_iter)[0] == value1
+ assert next(value_iter)[0] == value2
+ map_iter = self.map_state.iterator()
+ assert next(map_iter)[0] == key1
+ assert next(map_iter)[1] == (value2,)
+ self.map_state.removeKey(key1)
+ assert not self.map_state.containsKey(key1)
+ assert self.map_state.exists()
+ self.map_state.clear()
+ assert not self.map_state.exists()
+ yield pd.DataFrame({"id": key, "countAsString": str(count)})
+
+ def close(self) -> None:
+ pass
+
+
+class RowMapStateProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle):
+ # Test string type schemas
+ self.map_state = handle.getMapState("mapState", "name string", "count int")
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ count = 0
+ key1 = ("key1",)
+ key2 = ("key2",)
+ for row in rows:
+ if row.temperature is not None:
+ count += 1
+ value1 = count
+ value2 = count
+ if self.map_state.exists():
+ if self.map_state.containsKey(key1):
+ value1 += self.map_state.getValue(key1)[0]
+ if self.map_state.containsKey(key2):
+ value2 += self.map_state.getValue(key2)[0]
+ self.map_state.updateValue(key1, (value1,))
+ self.map_state.updateValue(key2, (value2,))
+ key_iter = self.map_state.keys()
+ assert next(key_iter)[0] == "key1"
+ assert next(key_iter)[0] == "key2"
+ value_iter = self.map_state.values()
+ assert next(value_iter)[0] == value1
+ assert next(value_iter)[0] == value2
+ map_iter = self.map_state.iterator()
+ assert next(map_iter)[0] == key1
+ assert next(map_iter)[1] == (value2,)
+ self.map_state.removeKey(key1)
+ assert not self.map_state.containsKey(key1)
+ assert self.map_state.exists()
+ self.map_state.clear()
+ assert not self.map_state.exists()
+ yield Row(id=key[0], countAsString=str(count))
+
+ def close(self) -> None:
+ pass
+
+
+# A stateful processor that inherits all behavior of MapStateProcessor except that it uses
+# TTL state with a large timeout.
+class PandasMapStateLargeTTLProcessor(PandasMapStateProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ key_schema = StructType([StructField("name", StringType(), True)])
+ value_schema = StructType([StructField("count", IntegerType(), True)])
+ self.map_state = handle.getMapState("mapState", key_schema, value_schema, 30000)
+ self.list_state = handle.getListState("listState", key_schema)
+
+
+# A stateful processor that inherits all behavior of MapStateProcessor except that it uses
+# TTL state with a large timeout.
+class RowMapStateLargeTTLProcessor(RowMapStateProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ key_schema = StructType([StructField("name", StringType(), True)])
+ value_schema = StructType([StructField("count", IntegerType(), True)])
+ self.map_state = handle.getMapState("mapState", key_schema, value_schema, 30000)
+ self.list_state = handle.getListState("listState", key_schema)
+
+
+class PandasBasicProcessor(StatefulProcessor):
+ # Schema definitions
+ state_schema = StructType(
+ [StructField("id", IntegerType(), True), StructField("name", StringType(), True)]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+ self.state.update((id_val, name))
+ yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})
+
+ def close(self) -> None:
+ pass
+
+
+class RowBasicProcessor(StatefulProcessor):
+ # Schema definitions
+ state_schema = StructType(
+ [StructField("id", IntegerType(), True), StructField("name", StringType(), True)]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+ self.state.update((id_val, name))
+ yield Row(id=key[0], value={"id": id_val, "name": name})
+
+ def close(self) -> None:
+ pass
+
+
+class PandasBasicProcessorNotNullable(StatefulProcessor):
+ # Schema definitions
+ state_schema = StructType(
+ [StructField("id", IntegerType(), False), StructField("name", StringType(), False)]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+ self.state.update((id_val, name))
+ yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})
+
+ def close(self) -> None:
+ pass
+
+
+class RowBasicProcessorNotNullable(StatefulProcessor):
+ # Schema definitions
+ state_schema = StructType(
+ [StructField("id", IntegerType(), False), StructField("name", StringType(), False)]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+ self.state.update((id_val, name))
+ yield Row(id=key[0], value={"id": id_val, "name": name})
+
+ def close(self) -> None:
+ pass
+
+
+class PandasAddFieldsProcessor(StatefulProcessor):
+ state_schema = StructType(
+ [
+ StructField("id", IntegerType(), True),
+ StructField("name", StringType(), True),
+ StructField("count", IntegerType(), True),
+ StructField("active", BooleanType(), True),
+ StructField("score", FloatType(), True),
+ ]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+
+ if self.state.exists():
+ state_data = self.state.get()
+ state_dict = {
+ "id": state_data[0],
+ "name": state_data[1],
+ "count": state_data[2],
+ "active": state_data[3],
+ "score": state_data[4],
+ }
+ else:
+ state_dict = {
+ "id": id_val,
+ "name": name,
+ "count": 100,
+ "active": True,
+ "score": 99.9,
+ }
+
+ self.state.update(
+ (
+ state_dict["id"],
+ state_dict["name"] + "0",
+ state_dict["count"],
+ state_dict["active"],
+ state_dict["score"],
+ )
+ )
+ yield pd.DataFrame({"id": [key[0]], "value": [state_dict]})
+
+ def close(self) -> None:
+ pass
+
+
+class RowAddFieldsProcessor(StatefulProcessor):
+ state_schema = StructType(
+ [
+ StructField("id", IntegerType(), True),
+ StructField("name", StringType(), True),
+ StructField("count", IntegerType(), True),
+ StructField("active", BooleanType(), True),
+ StructField("score", FloatType(), True),
+ ]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+
+ if self.state.exists():
+ state_data = self.state.get()
+ state_dict = {
+ "id": state_data[0],
+ "name": state_data[1],
+ "count": state_data[2],
+ "active": state_data[3],
+ "score": state_data[4],
+ }
+ else:
+ state_dict = {
+ "id": id_val,
+ "name": name,
+ "count": 100,
+ "active": True,
+ "score": 99.9,
+ }
+
+ self.state.update(
+ (
+ state_dict["id"],
+ state_dict["name"] + "0",
+ state_dict["count"],
+ state_dict["active"],
+ state_dict["score"],
+ )
+ )
+ yield Row(id=key[0], value=state_dict)
+
+ def close(self) -> None:
+ pass
+
+
+class PandasRemoveFieldsProcessor(StatefulProcessor):
+ # Schema definitions
+ state_schema = StructType(
+ [StructField("id", IntegerType(), True), StructField("name", StringType(), True)]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+ if self.state.exists():
+ name = self.state.get()[1]
+ self.state.update((id_val, name))
+ yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})
+
+ def close(self) -> None:
+ pass
+
+
+class RowRemoveFieldsProcessor(StatefulProcessor):
+ # Schema definitions
+ state_schema = StructType(
+ [StructField("id", IntegerType(), True), StructField("name", StringType(), True)]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+ if self.state.exists():
+ name = self.state.get()[1]
+ self.state.update((id_val, name))
+ yield Row(id=key[0], value={"id": id_val, "name": name})
+
+ def close(self) -> None:
+ pass
+
+
+class PandasReorderedFieldsProcessor(StatefulProcessor):
+ state_schema = StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("id", IntegerType(), True),
+ StructField("score", FloatType(), True),
+ StructField("count", IntegerType(), True),
+ StructField("active", BooleanType(), True),
+ ]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+
+ if self.state.exists():
+ state_data = self.state.get()
+ state_dict = {
+ "name": state_data[0],
+ "id": state_data[1],
+ "score": state_data[2],
+ "count": state_data[3],
+ "active": state_data[4],
+ }
+ else:
+ state_dict = {
+ "name": name,
+ "id": id_val,
+ "score": 99.9,
+ "count": 100,
+ "active": True,
+ }
+ self.state.update(
+ (
+ state_dict["name"],
+ state_dict["id"],
+ state_dict["score"],
+ state_dict["count"],
+ state_dict["active"],
+ )
+ )
+ yield pd.DataFrame({"id": [key[0]], "value": [state_dict]})
+
+ def close(self) -> None:
+ pass
+
+
+class RowReorderedFieldsProcessor(StatefulProcessor):
+ state_schema = StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("id", IntegerType(), True),
+ StructField("score", FloatType(), True),
+ StructField("count", IntegerType(), True),
+ StructField("active", BooleanType(), True),
+ ]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+
+ if self.state.exists():
+ state_data = self.state.get()
+ state_dict = {
+ "name": state_data[0],
+ "id": state_data[1],
+ "score": state_data[2],
+ "count": state_data[3],
+ "active": state_data[4],
+ }
+ else:
+ state_dict = {
+ "name": name,
+ "id": id_val,
+ "score": 99.9,
+ "count": 100,
+ "active": True,
+ }
+ self.state.update(
+ (
+ state_dict["name"],
+ state_dict["id"],
+ state_dict["score"],
+ state_dict["count"],
+ state_dict["active"],
+ )
+ )
+ yield Row(id=key[0], value=state_dict)
+
+ def close(self) -> None:
+ pass
+
+
+class PandasUpcastProcessor(StatefulProcessor):
+ state_schema = StructType(
+ [
+ StructField("id", LongType(), True), # Upcast from Int to Long
+ StructField("name", StringType(), True),
+ ]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+ if self.state.exists():
+ id_val += self.state.get()[0] + 1
+ self.state.update((id_val, name))
+ yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})
+
+ def close(self) -> None:
+ pass
+
+
+class RowUpcastProcessor(StatefulProcessor):
+ state_schema = StructType(
+ [
+ StructField("id", LongType(), True), # Upcast from Int to Long
+ StructField("name", StringType(), True),
+ ]
+ )
+
+ def init(self, handle):
+ self.state = handle.getValueState("state", self.state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ for pdf in rows:
+ pass
+ id_val = int(key[0])
+ name = f"name-{id_val}"
+ if self.state.exists():
+ id_val += self.state.get()[0] + 1
+ self.state.update((id_val, name))
+ yield Row(id=key[0], value={"id": id_val, "name": name})
+
+ def close(self) -> None:
+ pass
+
+
+class PandasMinEventTimeStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", StringType(), True)])
+ self.handle = handle
+ self.min_state = handle.getValueState("min_state", state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+ timestamp_list = []
+ for pdf in rows:
+            # int64 represents the timestamp in nanoseconds; convert back to seconds
+ timestamp_list.extend((pdf["eventTime"].astype("int64") // 10**9).tolist())
+
+ if self.min_state.exists():
+ cur_min = int(self.min_state.get()[0])
+ else:
+ cur_min = sys.maxsize
+ min_event_time = str(min(cur_min, min(timestamp_list)))
+
+ self.min_state.update((min_event_time,))
+
+ yield pd.DataFrame({"id": key, "timestamp": min_event_time})
+
+ def close(self) -> None:
+ pass
+
+
+class RowMinEventTimeStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", StringType(), True)])
+ self.handle = handle
+ self.min_state = handle.getValueState("min_state", state_schema)
+
+ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+ timestamp_list = []
+ for row in rows:
+            # the timestamp is in microseconds; convert back to seconds
+ timestamp_list.append(int(row.eventTime.timestamp()))
+
+ if self.min_state.exists():
+ cur_min = int(self.min_state.get()[0])
+ else:
+ cur_min = sys.maxsize
+ min_event_time = str(min(cur_min, min(timestamp_list)))
+
+ self.min_state.update((min_event_time,))
+
+ yield Row(id=key[0], timestamp=min_event_time)
+
+ def close(self) -> None:
+ pass
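
The factory classes in this new helper file exist so that a single test body can be driven against either the pandas-based or the Row-based implementation of the same stateful processor; the actual wiring appears in test_pandas_transform_with_state.py further down in this patch. A minimal sketch of that dispatch, assuming only that the helper module added above is importable:

    from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import (
        MapStateProcessorFactory,
    )

    def get_processor(factory, use_pandas: bool):
        # Pick the pandas or the Row flavor of the same stateful processor.
        return factory.pandas() if use_pandas else factory.row()

    # PandasMapStateProcessor when use_pandas=True, RowMapStateProcessor otherwise.
    processor = get_processor(MapStateProcessorFactory(), use_pandas=True)
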
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 692f9705411e0..7e2221fc1a77b 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -24,6 +24,8 @@
from pyspark.sql import Row
from pyspark.sql.functions import col, encode, lit
from pyspark.errors import PythonException
+from pyspark.sql.session import SparkSession
+from pyspark.sql.types import StructType
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
@@ -42,6 +44,8 @@
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
class MapInPandasTestsMixin:
+ spark: SparkSession
+
@staticmethod
def identity_dataframes_iter(*columns: str):
def func(iterator):
@@ -128,6 +132,27 @@ def func(iterator):
expected = df.collect()
self.assertEqual(actual, expected)
+ def test_not_null(self):
+ def func(iterator):
+ for _ in iterator:
+ yield pd.DataFrame({"a": [1, 2]})
+
+ schema = "a long not null"
+ df = self.spark.range(1).mapInPandas(func, schema)
+ self.assertEqual(df.schema, StructType.fromDDL(schema))
+ self.assertEqual(df.collect(), [Row(1), Row(2)])
+
+ def test_violate_not_null(self):
+ def func(iterator):
+ for _ in iterator:
+ yield pd.DataFrame({"a": [1, None]})
+
+ schema = "a long not null"
+ df = self.spark.range(1).mapInPandas(func, schema)
+ self.assertEqual(df.schema, StructType.fromDDL(schema))
+ with self.assertRaisesRegex(Exception, "is null"):
+ df.collect()
+
def test_different_output_length(self):
def func(iterator):
for _ in iterator:
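
The two tests added above (test_not_null and test_violate_not_null) check that mapInPandas enforces a NOT NULL constraint declared in the output schema. A standalone sketch of the same behavior, outside the test harness:

    import pandas as pd
    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.getOrCreate()

    def gen(batches):
        for _ in batches:
            # Yielding {"a": [1, None]} here instead would fail at collect() time
            # with a null-check error, which is what test_violate_not_null asserts.
            yield pd.DataFrame({"a": [1, 2]})

    df = spark.range(1).mapInPandas(gen, "a long not null")
    assert df.collect() == [Row(a=1), Row(a=2)]
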
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 3257430d45e94..e36ae3a86a28b 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -15,18 +15,18 @@
# limitations under the License.
#
+from abc import abstractmethod
+
import json
import os
import time
import tempfile
-from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
-from typing import Iterator
+from pyspark.sql.streaming import StatefulProcessor
import unittest
from typing import cast
from pyspark import SparkConf
-from pyspark.errors import PySparkRuntimeError
from pyspark.sql.functions import array_sort, col, explode, split
from pyspark.sql.types import (
StringType,
@@ -35,9 +35,6 @@
Row,
IntegerType,
TimestampType,
- LongType,
- BooleanType,
- FloatType,
)
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
@@ -48,28 +45,44 @@
pyarrow_requirement_message,
)
-if have_pandas:
- import pandas as pd
+from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import (
+ SimpleStatefulProcessorWithInitialStateFactory,
+ StatefulProcessorWithInitialStateTimersFactory,
+ StatefulProcessorWithListStateInitialStateFactory,
+ EventTimeStatefulProcessorFactory,
+ ProcTimeStatefulProcessorFactory,
+ SimpleStatefulProcessorFactory,
+ StatefulProcessorChainingOpsFactory,
+ SimpleTTLStatefulProcessorFactory,
+ TTLStatefulProcessorFactory,
+ InvalidSimpleStatefulProcessorFactory,
+ ListStateProcessorFactory,
+ ListStateLargeListProcessorFactory,
+ ListStateLargeTTLProcessorFactory,
+ MapStateProcessorFactory,
+ MapStateLargeTTLProcessorFactory,
+ BasicProcessorFactory,
+ BasicProcessorNotNullableFactory,
+ AddFieldsProcessorFactory,
+ RemoveFieldsProcessorFactory,
+ ReorderedFieldsProcessorFactory,
+ UpcastProcessorFactory,
+ MinEventTimeStatefulProcessorFactory,
+)
-@unittest.skipIf(
- not have_pandas or not have_pyarrow,
- cast(str, pandas_requirement_message or pyarrow_requirement_message),
-)
-class TransformWithStateInPandasTestsMixin:
+class TransformWithStateTestsMixin:
@classmethod
- def conf(cls):
- cfg = SparkConf()
- cfg.set("spark.sql.shuffle.partitions", "5")
- cfg.set(
- "spark.sql.streaming.stateStore.providerClass",
- "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
- )
- cfg.set("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch", "2")
- cfg.set("spark.sql.session.timeZone", "UTC")
- # TODO SPARK-49046 this config is to stop query from FEB sink gracefully
- cfg.set("spark.sql.streaming.noDataMicroBatches.enabled", "false")
- return cfg
+ @abstractmethod
+ def use_pandas(cls) -> bool:
+ ...
+
+ @classmethod
+ def get_processor(cls, stateful_processor_factory) -> StatefulProcessor:
+ if cls.use_pandas():
+ return stateful_processor_factory.pandas()
+ else:
+ return stateful_processor_factory.row()
def _prepare_input_data(self, input_path, col1, col2):
with open(input_path, "w") as fw:
@@ -117,9 +130,9 @@ def build_test_df_with_3_cols(self, input_path):
)
return df_final
- def _test_transform_with_state_in_pandas_basic(
+ def _test_transform_with_state_basic(
self,
- stateful_processor,
+ stateful_processor_factory,
check_results,
single_batch=False,
timeMode="None",
@@ -147,16 +160,26 @@ def _test_transform_with_state_in_pandas_basic(
]
)
- q = (
- df.groupBy("id")
- .transformWithStateInPandas(
+ stateful_processor = self.get_processor(stateful_processor_factory)
+ if self.use_pandas():
+ tws_df = df.groupBy("id").transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode=timeMode,
+ initialState=initial_state,
+ )
+ else:
+ tws_df = df.groupBy("id").transformWithState(
statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Update",
timeMode=timeMode,
initialState=initial_state,
)
- .writeStream.queryName("this_query")
+
+ q = (
+ tws_df.writeStream.queryName("this_query")
.option("checkpointLocation", checkpoint_path)
.foreachBatch(check_results)
.outputMode("update")
@@ -169,7 +192,7 @@ def _test_transform_with_state_in_pandas_basic(
q.awaitTermination(10)
self.assertTrue(q.exception() is None)
- def test_transform_with_state_in_pandas_basic(self):
+ def test_transform_with_state_basic(self):
def check_results(batch_df, batch_id):
if batch_id == 0:
assert set(batch_df.sort("id").collect()) == {
@@ -182,20 +205,20 @@ def check_results(batch_df, batch_id):
Row(id="1", countAsString="2"),
}
- self._test_transform_with_state_in_pandas_basic(SimpleStatefulProcessor(), check_results)
+ self._test_transform_with_state_basic(SimpleStatefulProcessorFactory(), check_results)
- def test_transform_with_state_in_pandas_non_exist_value_state(self):
+ def test_transform_with_state_non_exist_value_state(self):
def check_results(batch_df, _):
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="0"),
Row(id="1", countAsString="0"),
}
- self._test_transform_with_state_in_pandas_basic(
- InvalidSimpleStatefulProcessor(), check_results, True
+ self._test_transform_with_state_basic(
+ InvalidSimpleStatefulProcessorFactory(), check_results, True
)
- def test_transform_with_state_in_pandas_query_restarts(self):
+ def test_transform_with_state_query_restarts(self):
root_path = tempfile.mkdtemp()
input_path = root_path + "/input"
os.makedirs(input_path, exist_ok=True)
@@ -217,15 +240,24 @@ def test_transform_with_state_in_pandas_query_restarts(self):
]
)
- base_query = (
- df.groupBy("id")
- .transformWithStateInPandas(
- statefulProcessor=SimpleStatefulProcessor(),
+ stateful_processor = self.get_processor(SimpleStatefulProcessorFactory())
+ if self.use_pandas():
+ tws_df = df.groupBy("id").transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Update",
timeMode="None",
)
- .writeStream.queryName("this_query")
+ else:
+ tws_df = df.groupBy("id").transformWithState(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="None",
+ )
+
+ base_query = (
+ tws_df.writeStream.queryName("this_query")
.format("parquet")
.outputMode("append")
.option("checkpointLocation", checkpoint_path)
@@ -260,46 +292,121 @@ def test_transform_with_state_in_pandas_query_restarts(self):
Row(id="1", countAsString="2"),
}
- def test_transform_with_state_in_pandas_list_state(self):
+ def test_transform_with_state_list_state(self):
def check_results(batch_df, _):
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
- self._test_transform_with_state_in_pandas_basic(ListStateProcessor(), check_results, True)
+ self._test_transform_with_state_basic(
+ ListStateProcessorFactory(), check_results, True, "processingTime"
+ )
+
+ def test_transform_with_state_list_state_large_list(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ expected_prev_elements = ""
+ expected_updated_elements = ",".join(map(str, range(90)))
+ else:
+ # batch_id == 1:
+ expected_prev_elements = ",".join(map(str, range(90)))
+ expected_updated_elements = ",".join(map(str, range(180)))
+
+ assert set(batch_df.sort("id").collect()) == {
+ Row(
+ id="0",
+ prevElements=expected_prev_elements,
+ updatedElements=expected_updated_elements,
+ ),
+ Row(
+ id="1",
+ prevElements=expected_prev_elements,
+ updatedElements=expected_updated_elements,
+ ),
+ }
+
+ input_path = tempfile.mkdtemp()
+ checkpoint_path = tempfile.mkdtemp()
+
+ self._prepare_test_resource1(input_path)
+ time.sleep(2)
+ self._prepare_test_resource2(input_path)
+
+ df = self._build_test_df(input_path)
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("prevElements", StringType(), True),
+ StructField("updatedElements", StringType(), True),
+ ]
+ )
+
+ stateful_processor = self.get_processor(ListStateLargeListProcessorFactory())
+ if self.use_pandas():
+ tws_df = df.groupBy("id").transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="none",
+ )
+ else:
+ tws_df = df.groupBy("id").transformWithState(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="none",
+ )
+
+ q = (
+ tws_df.writeStream.queryName("this_query")
+ .option("checkpointLocation", checkpoint_path)
+ .foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+ self.assertEqual(q.name, "this_query")
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+ self.assertTrue(q.exception() is None)
# test list state with ttl has the same behavior as list state when state doesn't expire.
- def test_transform_with_state_in_pandas_list_state_large_ttl(self):
+ def test_transform_with_state_list_state_large_ttl(self):
def check_results(batch_df, batch_id):
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
- self._test_transform_with_state_in_pandas_basic(
- ListStateLargeTTLProcessor(), check_results, True, "processingTime"
+ self._test_transform_with_state_basic(
+ ListStateLargeTTLProcessorFactory(), check_results, True, "processingTime"
)
- def test_transform_with_state_in_pandas_map_state(self):
+ def test_transform_with_state_map_state(self):
def check_results(batch_df, _):
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
- self._test_transform_with_state_in_pandas_basic(MapStateProcessor(), check_results, True)
+ self._test_transform_with_state_basic(MapStateProcessorFactory(), check_results, True)
# test map state with ttl has the same behavior as map state when state doesn't expire.
- def test_transform_with_state_in_pandas_map_state_large_ttl(self):
+ def test_transform_with_state_map_state_large_ttl(self):
def check_results(batch_df, batch_id):
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
- self._test_transform_with_state_in_pandas_basic(
- MapStateLargeTTLProcessor(), check_results, True, "processingTime"
+ self._test_transform_with_state_basic(
+ MapStateLargeTTLProcessorFactory(), check_results, True, "processingTime"
)
# test value state with ttl has the same behavior as value state when
@@ -317,8 +424,8 @@ def check_results(batch_df, batch_id):
Row(id="1", countAsString="2"),
}
- self._test_transform_with_state_in_pandas_basic(
- SimpleTTLStatefulProcessor(), check_results, False, "processingTime"
+ self._test_transform_with_state_basic(
+ SimpleTTLStatefulProcessorFactory(), check_results, False, "processingTime"
)
# TODO SPARK-50908 holistic fix for TTL suite
@@ -394,18 +501,23 @@ def check_results(batch_df, batch_id):
]
)
- q = (
- df.groupBy("id")
- .transformWithStateInPandas(
- statefulProcessor=TTLStatefulProcessor(),
+ stateful_processor = self.get_processor(TTLStatefulProcessorFactory())
+ if self.use_pandas():
+ tws_df = df.groupBy("id").transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Update",
timeMode="processingTime",
)
- .writeStream.foreachBatch(check_results)
- .outputMode("update")
- .start()
- )
+ else:
+ tws_df = df.groupBy("id").transformWithState(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="processingTime",
+ )
+
+ q = tws_df.writeStream.foreachBatch(check_results).outputMode("update").start()
self.assertTrue(q.isActive)
q.processAllAvailable()
q.stop()
@@ -414,7 +526,7 @@ def check_results(batch_df, batch_id):
finally:
input_dir.cleanup()
- def _test_transform_with_state_in_pandas_proc_timer(self, stateful_processor, check_results):
+ def _test_transform_with_state_proc_timer(self, stateful_processor_factory, check_results):
input_path = tempfile.mkdtemp()
self._prepare_test_resource3(input_path)
time.sleep(2)
@@ -437,15 +549,24 @@ def _test_transform_with_state_in_pandas_proc_timer(self, stateful_processor, ch
)
query_name = "processing_time_test_query"
- q = (
- df.groupBy("id")
- .transformWithStateInPandas(
+ stateful_processor = self.get_processor(stateful_processor_factory)
+ if self.use_pandas():
+ tws_df = df.groupBy("id").transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="processingtime",
+ )
+ else:
+ tws_df = df.groupBy("id").transformWithState(
statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Update",
timeMode="processingtime",
)
- .writeStream.queryName(query_name)
+
+ q = (
+ tws_df.writeStream.queryName(query_name)
.foreachBatch(check_results)
.outputMode("update")
.start()
@@ -457,7 +578,7 @@ def _test_transform_with_state_in_pandas_proc_timer(self, stateful_processor, ch
q.awaitTermination(10)
self.assertTrue(q.exception() is None)
- def test_transform_with_state_in_pandas_proc_timer(self):
+ def test_transform_with_state_proc_timer(self):
def check_results(batch_df, batch_id):
# helper function to check expired timestamp is smaller than current processing time
def check_timestamp(batch_df):
@@ -497,11 +618,13 @@ def check_timestamp(batch_df):
Row(id="1", countAsString="5"),
}
- self._test_transform_with_state_in_pandas_proc_timer(
- ProcTimeStatefulProcessor(), check_results
+ self._test_transform_with_state_proc_timer(
+ ProcTimeStatefulProcessorFactory(), check_results
)
- def _test_transform_with_state_in_pandas_event_time(self, stateful_processor, check_results):
+ def _test_transform_with_state_event_time(
+ self, stateful_processor_factory, check_results, time_mode="eventtime"
+ ):
import pyspark.sql.functions as f
input_path = tempfile.mkdtemp()
@@ -516,6 +639,7 @@ def prepare_batch2(input_path):
def prepare_batch3(input_path):
with open(input_path + "/text-test2.txt", "w") as fw:
+ fw.write("a, 2\n")
fw.write("a, 11\n")
fw.write("a, 13\n")
fw.write("a, 15\n")
@@ -540,15 +664,24 @@ def prepare_batch3(input_path):
)
query_name = "event_time_test_query"
- q = (
- df.groupBy("id")
- .transformWithStateInPandas(
+ stateful_processor = self.get_processor(stateful_processor_factory)
+ if self.use_pandas():
+ tws_df = df.groupBy("id").transformWithStateInPandas(
statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Update",
- timeMode="eventtime",
+ timeMode=time_mode,
+ )
+ else:
+ tws_df = df.groupBy("id").transformWithState(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode=time_mode,
)
- .writeStream.queryName(query_name)
+
+ q = (
+ tws_df.writeStream.queryName(query_name)
.foreachBatch(check_results)
.outputMode("update")
.start()
@@ -560,7 +693,7 @@ def prepare_batch3(input_path):
q.awaitTermination(10)
self.assertTrue(q.exception() is None)
- def test_transform_with_state_in_pandas_event_time(self):
+ def test_transform_with_state_event_time(self):
def check_results(batch_df, batch_id):
if batch_id == 0:
# watermark for late event = 0
@@ -587,13 +720,42 @@ def check_results(batch_df, batch_id):
Row(id="a-expired", timestamp="10000"),
}
- self._test_transform_with_state_in_pandas_event_time(
- EventTimeStatefulProcessor(), check_results
+ self._test_transform_with_state_event_time(
+ EventTimeStatefulProcessorFactory(), check_results
)
- def _test_transform_with_state_init_state_in_pandas(
+ def test_transform_with_state_with_wmark_and_non_event_time(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ # watermark for late event = 0 and min event = 20
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="20"),
+ }
+ elif batch_id == 1:
+ # watermark for late event = 0 and min event = 4
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="4"),
+ }
+ elif batch_id == 2:
+ # watermark for late event = 10 and min event = 2 with no filtering
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="2"),
+ }
+ else:
+ for q in batch_df.sparkSession.streams.active:
+ q.stop()
+
+ self._test_transform_with_state_event_time(
+ MinEventTimeStatefulProcessorFactory(), check_results, "None"
+ )
+
+ self._test_transform_with_state_event_time(
+ MinEventTimeStatefulProcessorFactory(), check_results, "ProcessingTime"
+ )
+
+ def _test_transform_with_state_init_state(
self,
- stateful_processor,
+ stateful_processor_factory,
check_results,
time_mode="None",
checkpoint_path=None,
@@ -618,16 +780,26 @@ def _test_transform_with_state_init_state_in_pandas(
data = [("0", 789), ("3", 987)]
initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id")
- q = (
- df.groupBy("id")
- .transformWithStateInPandas(
+ stateful_processor = self.get_processor(stateful_processor_factory)
+ if self.use_pandas():
+ tws_df = df.groupBy("id").transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode=time_mode,
+ initialState=initial_state,
+ )
+ else:
+ tws_df = df.groupBy("id").transformWithState(
statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Update",
timeMode=time_mode,
initialState=initial_state,
)
- .writeStream.queryName("this_query")
+
+ q = (
+ tws_df.writeStream.queryName("this_query")
.option("checkpointLocation", checkpoint_path)
.foreachBatch(check_results)
.outputMode("update")
@@ -640,7 +812,7 @@ def _test_transform_with_state_init_state_in_pandas(
q.awaitTermination(10)
self.assertTrue(q.exception() is None)
- def test_transform_with_state_init_state_in_pandas(self):
+ def test_transform_with_state_init_state(self):
def check_results(batch_df, batch_id):
if batch_id == 0:
# for key 0, initial state was processed and it was only processed once;
@@ -659,12 +831,12 @@ def check_results(batch_df, batch_id):
Row(id="3", value=str(987 + 12)),
}
- self._test_transform_with_state_init_state_in_pandas(
- SimpleStatefulProcessorWithInitialState(), check_results
+ self._test_transform_with_state_init_state(
+ SimpleStatefulProcessorWithInitialStateFactory(), check_results
)
def _test_transform_with_state_non_contiguous_grouping_cols(
- self, stateful_processor, check_results, initial_state=None
+ self, stateful_processor_factory, check_results, initial_state=None
):
input_path = tempfile.mkdtemp()
self._prepare_input_data_with_3_cols(
@@ -685,16 +857,26 @@ def _test_transform_with_state_non_contiguous_grouping_cols(
]
)
- q = (
- df.groupBy("id1", "id2")
- .transformWithStateInPandas(
+ stateful_processor = self.get_processor(stateful_processor_factory)
+ if self.use_pandas():
+ tws_df = df.groupBy("id1", "id2").transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="None",
+ initialState=initial_state,
+ )
+ else:
+ tws_df = df.groupBy("id1", "id2").transformWithState(
statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Update",
timeMode="None",
initialState=initial_state,
)
- .writeStream.queryName("this_query")
+
+ q = (
+ tws_df.writeStream.queryName("this_query")
.foreachBatch(check_results)
.outputMode("update")
.start()
@@ -714,7 +896,7 @@ def check_results(batch_df, batch_id):
}
self._test_transform_with_state_non_contiguous_grouping_cols(
- SimpleStatefulProcessorWithInitialState(), check_results
+ SimpleStatefulProcessorWithInitialStateFactory(), check_results
)
def test_transform_with_state_non_contiguous_grouping_cols_with_init_state(self):
@@ -732,11 +914,15 @@ def check_results(batch_df, batch_id):
).groupBy("id1", "id2")
self._test_transform_with_state_non_contiguous_grouping_cols(
- SimpleStatefulProcessorWithInitialState(), check_results, initial_state
+ SimpleStatefulProcessorWithInitialStateFactory(), check_results, initial_state
)
- def _test_transform_with_state_in_pandas_chaining_ops(
- self, stateful_processor, check_results, timeMode="None", grouping_cols=["outputTimestamp"]
+ def _test_transform_with_state_chaining_ops(
+ self,
+ stateful_processor_factory,
+ check_results,
+ timeMode="None",
+ grouping_cols=["outputTimestamp"],
):
import pyspark.sql.functions as f
@@ -763,16 +949,26 @@ def _test_transform_with_state_in_pandas_chaining_ops(
]
)
- q = (
- df.groupBy("id")
- .transformWithStateInPandas(
+ stateful_processor = self.get_processor(stateful_processor_factory)
+ if self.use_pandas():
+ tws_df = df.groupBy("id").transformWithStateInPandas(
statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Append",
timeMode=timeMode,
eventTimeColumnName="outputTimestamp",
)
- .groupBy(grouping_cols)
+ else:
+ tws_df = df.groupBy("id").transformWithState(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Append",
+ timeMode=timeMode,
+ eventTimeColumnName="outputTimestamp",
+ )
+
+ q = (
+ tws_df.groupBy(grouping_cols)
.count()
.writeStream.queryName("chaining_ops_query")
.foreachBatch(check_results)
@@ -785,7 +981,7 @@ def _test_transform_with_state_in_pandas_chaining_ops(
q.processAllAvailable()
q.awaitTermination(10)
- def test_transform_with_state_in_pandas_chaining_ops(self):
+ def test_transform_with_state_chaining_ops(self):
def check_results(batch_df, batch_id):
import datetime
@@ -810,11 +1006,14 @@ def check_results(batch_df, batch_id):
Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 15), count=1),
}
- self._test_transform_with_state_in_pandas_chaining_ops(
- StatefulProcessorChainingOps(), check_results, "eventTime"
+ self._test_transform_with_state_chaining_ops(
+ StatefulProcessorChainingOpsFactory(), check_results, "eventTime"
)
- self._test_transform_with_state_in_pandas_chaining_ops(
- StatefulProcessorChainingOps(), check_results, "eventTime", ["outputTimestamp", "id"]
+ self._test_transform_with_state_chaining_ops(
+ StatefulProcessorChainingOpsFactory(),
+ check_results,
+ "eventTime",
+ ["outputTimestamp", "id"],
)
def test_transform_with_state_init_state_with_timers(self):
@@ -842,11 +1041,11 @@ def check_results(batch_df, batch_id):
Row(id="3", value=str(987 + 12)),
}
- self._test_transform_with_state_init_state_in_pandas(
- StatefulProcessorWithInitialStateTimers(), check_results, "processingTime"
+ self._test_transform_with_state_init_state(
+ StatefulProcessorWithInitialStateTimersFactory(), check_results, "processingTime"
)
- def test_transform_with_state_in_pandas_batch_query(self):
+ def test_transform_with_state_batch_query(self):
data = [("0", 123), ("0", 46), ("1", 146), ("1", 346)]
df = self.spark.createDataFrame(data, "id string, temperature int")
@@ -856,18 +1055,28 @@ def test_transform_with_state_in_pandas_batch_query(self):
StructField("countAsString", StringType(), True),
]
)
- batch_result = df.groupBy("id").transformWithStateInPandas(
- statefulProcessor=MapStateProcessor(),
- outputStructType=output_schema,
- outputMode="Update",
- timeMode="None",
- )
+ stateful_processor = self.get_processor(MapStateProcessorFactory())
+ if self.use_pandas():
+ batch_result = df.groupBy("id").transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="None",
+ )
+ else:
+ batch_result = df.groupBy("id").transformWithState(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="None",
+ )
+
assert set(batch_result.sort("id").collect()) == {
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
- def test_transform_with_state_in_pandas_batch_query_initial_state(self):
+ def test_transform_with_state_batch_query_initial_state(self):
data = [("0", 123), ("0", 46), ("1", 146), ("1", 346)]
df = self.spark.createDataFrame(data, "id string, temperature int")
@@ -882,13 +1091,25 @@ def test_transform_with_state_in_pandas_batch_query_initial_state(self):
StructField("value", StringType(), True),
]
)
- batch_result = df.groupBy("id").transformWithStateInPandas(
- statefulProcessor=SimpleStatefulProcessorWithInitialState(),
- outputStructType=output_schema,
- outputMode="Update",
- timeMode="None",
- initialState=initial_state,
- )
+
+ stateful_processor = self.get_processor(SimpleStatefulProcessorWithInitialStateFactory())
+ if self.use_pandas():
+ batch_result = df.groupBy("id").transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="None",
+ initialState=initial_state,
+ )
+ else:
+ batch_result = df.groupBy("id").transformWithState(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="None",
+ initialState=initial_state,
+ )
+
assert set(batch_result.sort("id").collect()) == {
Row(id="0", value=str(789 + 123 + 46)),
Row(id="1", value=str(146 + 346)),
@@ -914,6 +1135,12 @@ def test_transform_with_map_state_metadata_with_init_state(self):
def _test_transform_with_map_state_metadata(self, initial_state):
checkpoint_path = tempfile.mktemp()
+ # This has to be computed outside of the foreachBatch (FEB) function to avoid serialization issues.
+ if self.use_pandas():
+ expected_operator_name = "transformWithStateInPandasExec"
+ else:
+ expected_operator_name = "transformWithStateInPySparkExec"
+
def check_results(batch_df, batch_id):
if batch_id == 0:
assert set(batch_df.sort("id").collect()) == {
@@ -925,6 +1152,7 @@ def check_results(batch_df, batch_id):
metadata_df = batch_df.sparkSession.read.format("state-metadata").load(
checkpoint_path
)
+
assert set(
metadata_df.select(
"operatorId",
@@ -937,7 +1165,7 @@ def check_results(batch_df, batch_id):
) == {
Row(
operatorId=0,
- operatorName="transformWithStateInPandasExec",
+ operatorName=expected_operator_name,
stateStoreName="default",
numPartitions=5,
minBatchId=0,
@@ -1016,8 +1244,8 @@ def check_results(batch_df, batch_id):
)
assert list_state_df.isEmpty()
- self._test_transform_with_state_in_pandas_basic(
- MapStateLargeTTLProcessor(),
+ self._test_transform_with_state_basic(
+ MapStateLargeTTLProcessorFactory(),
check_results,
True,
"processingTime",
@@ -1044,10 +1272,10 @@ def check_results(batch_df, batch_id):
metadata_df.select("operatorProperties").collect()[0][0]
)
state_var_list = operator_properties_json_obj["stateVariables"]
- assert len(state_var_list) == 3
+ assert len(state_var_list) == 4
for state_var in state_var_list:
- if state_var["stateName"] in ["listState1", "listState2"]:
- state_var["stateVariableType"] == "ListState"
+ if state_var["stateName"] in ["listState1", "listState2", "listStateTimestamp"]:
+ assert state_var["stateVariableType"] == "ListState"
else:
assert state_var["stateName"] == "$procTimers_keyToTimestamp"
assert state_var["stateVariableType"] == "TimerState"
@@ -1094,10 +1322,10 @@ def check_results(batch_df, batch_id):
Row(groupingKey="1", valueSortedList=[20, 20, 120, 120, 222]),
]
- self._test_transform_with_state_in_pandas_basic(
- ListStateProcessor(),
+ self._test_transform_with_state_basic(
+ ListStateProcessorFactory(),
check_results,
- True,
+ False,
"processingTime",
checkpoint_path=checkpoint_path,
initial_state=None,
@@ -1186,8 +1414,8 @@ def check_results(batch_df, batch_id):
with self.sql_conf(
{"spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled": "true"}
):
- self._test_transform_with_state_in_pandas_basic(
- SimpleStatefulProcessor(),
+ self._test_transform_with_state_basic(
+ SimpleStatefulProcessorFactory(),
check_results,
False,
"processingTime",
@@ -1230,8 +1458,8 @@ def dataframe_to_value_list(output_df):
# run a tws query and read state data source dataframe from its checkpoint
checkpoint_path = tempfile.mkdtemp()
- self._test_transform_with_state_in_pandas_basic(
- ListStateProcessor(), check_results, True, checkpoint_path=checkpoint_path
+ self._test_transform_with_state_basic(
+ ListStateProcessorFactory(), check_results, True, checkpoint_path=checkpoint_path
)
list_state_df = (
self.spark.read.format("statestore")
@@ -1244,8 +1472,8 @@ def dataframe_to_value_list(output_df):
# run a new tws query and pass state data source dataframe as initial state
# multiple rows exist in the initial state with the same grouping key
new_checkpoint_path = tempfile.mkdtemp()
- self._test_transform_with_state_init_state_in_pandas(
- StatefulProcessorWithListStateInitialState(),
+ self._test_transform_with_state_init_state(
+ StatefulProcessorWithListStateInitialStateFactory(),
check_results_for_new_query,
checkpoint_path=new_checkpoint_path,
initial_state=init_df,
@@ -1255,11 +1483,12 @@ def dataframe_to_value_list(output_df):
def test_transform_with_state_with_timers_single_partition(self):
with self.sql_conf({"spark.sql.shuffle.partitions": "1"}):
self.test_transform_with_state_init_state_with_timers()
- self.test_transform_with_state_in_pandas_event_time()
- self.test_transform_with_state_in_pandas_proc_timer()
+ self.test_transform_with_state_event_time()
+ self.test_transform_with_state_proc_timer()
self.test_transform_with_state_restart_with_multiple_rows_init_state()
- def _run_evolution_test(self, processor, checkpoint_dir, check_results, df):
+ def _run_evolution_test(self, processor_factory, checkpoint_dir, check_results, df):
+ processor = self.get_processor(processor_factory)
output_schema = StructType(
[
StructField("id", StringType(), True),
@@ -1271,15 +1500,23 @@ def _run_evolution_test(self, processor, checkpoint_dir, check_results, df):
for q in self.spark.streams.active:
q.stop()
- q = (
- df.groupBy("id")
- .transformWithStateInPandas(
+ if self.use_pandas():
+ tws_df = df.groupBy("id").transformWithStateInPandas(
statefulProcessor=processor,
outputStructType=output_schema,
outputMode="Update",
timeMode="None",
)
- .writeStream.queryName("evolution_test")
+ else:
+ tws_df = df.groupBy("id").transformWithState(
+ statefulProcessor=processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="None",
+ )
+
+ q = (
+ tws_df.writeStream.queryName("evolution_test")
.option("checkpointLocation", checkpoint_dir)
.foreachBatch(check_results)
.outputMode("update")
@@ -1307,7 +1544,9 @@ def check_basic_state(batch_df, batch_id):
assert result.value["id"] == 0 # First ID from test data
assert result.value["name"] == "name-0"
- self._run_evolution_test(BasicProcessor(), checkpoint_dir, check_basic_state, df)
+ self._run_evolution_test(
+ BasicProcessorFactory(), checkpoint_dir, check_basic_state, df
+ )
self._prepare_test_resource2(input_path)
@@ -1320,7 +1559,9 @@ def check_add_fields(batch_df, batch_id):
assert result.value["active"] is None
assert result.value["score"] is None
- self._run_evolution_test(AddFieldsProcessor(), checkpoint_dir, check_add_fields, df)
+ self._run_evolution_test(
+ AddFieldsProcessorFactory(), checkpoint_dir, check_add_fields, df
+ )
self._prepare_test_resource3(input_path)
# Test 3: Remove fields
@@ -1330,7 +1571,7 @@ def check_remove_fields(batch_df, batch_id):
assert result.value["name"] == "name-00"
self._run_evolution_test(
- RemoveFieldsProcessor(), checkpoint_dir, check_remove_fields, df
+ RemoveFieldsProcessorFactory(), checkpoint_dir, check_remove_fields, df
)
self._prepare_test_resource4(input_path)
@@ -1341,7 +1582,7 @@ def check_reorder_fields(batch_df, batch_id):
assert result.value["id"] == 0
self._run_evolution_test(
- ReorderedFieldsProcessor(), checkpoint_dir, check_reorder_fields, df
+ ReorderedFieldsProcessorFactory(), checkpoint_dir, check_reorder_fields, df
)
self._prepare_test_resource5(input_path)
@@ -1351,7 +1592,7 @@ def check_upcast(batch_df, batch_id):
assert result.value["id"] == 1
assert result.value["name"] == "name-0"
- self._run_evolution_test(UpcastProcessor(), checkpoint_dir, check_upcast, df)
+ self._run_evolution_test(UpcastProcessorFactory(), checkpoint_dir, check_upcast, df)
# This test case verifies that an exception is thrown when downcasting, which violates
# Avro's schema evolution rules
@@ -1368,7 +1609,9 @@ def check_add_fields(batch_df, batch_id):
assert results[0].value["count"] == 100
assert results[0].value["active"]
- self._run_evolution_test(AddFieldsProcessor(), checkpoint_dir, check_add_fields, df)
+ self._run_evolution_test(
+ AddFieldsProcessorFactory(), checkpoint_dir, check_add_fields, df
+ )
self._prepare_test_resource2(input_path)
def check_upcast(batch_df, batch_id):
@@ -1376,7 +1619,7 @@ def check_upcast(batch_df, batch_id):
assert result.value["name"] == "name-0"
# Long
- self._run_evolution_test(UpcastProcessor(), checkpoint_dir, check_upcast, df)
+ self._run_evolution_test(UpcastProcessorFactory(), checkpoint_dir, check_upcast, df)
self._prepare_test_resource3(input_path)
def check_basic_state(batch_df, batch_id):
@@ -1387,7 +1630,7 @@ def check_basic_state(batch_df, batch_id):
# Int
try:
self._run_evolution_test(
- BasicProcessor(),
+ BasicProcessorFactory(),
checkpoint_dir,
check_basic_state,
df,
@@ -1420,7 +1663,7 @@ def check_basic_state(batch_df, batch_id):
try:
self._run_evolution_test(
- BasicProcessorNotNullable(),
+ BasicProcessorNotNullableFactory(),
checkpoint_dir,
check_basic_state,
df,
@@ -1438,603 +1681,62 @@ def check_basic_state(batch_df, batch_id):
)
-class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
- # this dict is the same as input initial state dataframe
- dict = {("0",): 789, ("3",): 987}
-
- def init(self, handle: StatefulProcessorHandle) -> None:
- state_schema = StructType([StructField("value", IntegerType(), True)])
- self.value_state = handle.getValueState("value_state", state_schema)
- self.handle = handle
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- exists = self.value_state.exists()
- if exists:
- value_row = self.value_state.get()
- existing_value = value_row[0]
- else:
- existing_value = 0
-
- accumulated_value = existing_value
-
- for pdf in rows:
- value = pdf["temperature"].astype(int).sum()
- accumulated_value += value
-
- self.value_state.update((accumulated_value,))
-
- if len(key) > 1:
- yield pd.DataFrame(
- {"id1": (key[0],), "id2": (key[1],), "value": str(accumulated_value)}
- )
- else:
- yield pd.DataFrame({"id": key, "value": str(accumulated_value)})
-
- def handleInitialState(self, key, initialState, timerValues) -> None:
- init_val = initialState.at[0, "initVal"]
- self.value_state.update((init_val,))
- if len(key) == 1:
- assert self.dict[key] == init_val
-
- def close(self) -> None:
- pass
-
-
-class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState):
- def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
- self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs())
- str_key = f"{str(key[0])}-expired"
- yield pd.DataFrame({"id": (str_key,), "value": str(expiredTimerInfo.getExpiryTimeInMs())})
-
- def handleInitialState(self, key, initialState, timerValues) -> None:
- super().handleInitialState(key, initialState, timerValues)
- self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() - 1)
-
-
-class StatefulProcessorWithListStateInitialState(SimpleStatefulProcessorWithInitialState):
- def init(self, handle: StatefulProcessorHandle) -> None:
- super().init(handle)
- list_ele_schema = StructType([StructField("value", IntegerType(), True)])
- self.list_state = handle.getListState("list_state", list_ele_schema)
-
- def handleInitialState(self, key, initialState, timerValues) -> None:
- for val in initialState["initVal"].tolist():
- self.list_state.appendValue((val,))
-
-
-# A stateful processor that output the max event time it has seen. Register timer for
-# current watermark. Clear max state if timer expires.
-class EventTimeStatefulProcessor(StatefulProcessor):
- def init(self, handle: StatefulProcessorHandle) -> None:
- state_schema = StructType([StructField("value", StringType(), True)])
- self.handle = handle
- self.max_state = handle.getValueState("max_state", state_schema)
-
- def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
- self.max_state.clear()
- self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs())
- str_key = f"{str(key[0])}-expired"
- yield pd.DataFrame(
- {"id": (str_key,), "timestamp": str(expiredTimerInfo.getExpiryTimeInMs())}
- )
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- timestamp_list = []
- for pdf in rows:
- # int64 will represent timestamp in nanosecond, restore to second
- timestamp_list.extend((pdf["eventTime"].astype("int64") // 10**9).tolist())
-
- if self.max_state.exists():
- cur_max = int(self.max_state.get()[0])
- else:
- cur_max = 0
- max_event_time = str(max(cur_max, max(timestamp_list)))
-
- self.max_state.update((max_event_time,))
- self.handle.registerTimer(timerValues.getCurrentWatermarkInMs())
-
- yield pd.DataFrame({"id": key, "timestamp": max_event_time})
-
- def close(self) -> None:
- pass
-
-
-# A stateful processor that output the accumulation of count of input rows; register
-# processing timer and clear the counter if timer expires.
-class ProcTimeStatefulProcessor(StatefulProcessor):
- def init(self, handle: StatefulProcessorHandle) -> None:
- state_schema = StructType([StructField("value", StringType(), True)])
- self.handle = handle
- self.count_state = handle.getValueState("count_state", state_schema)
-
- def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
- # reset count state each time the timer is expired
- timer_list_1 = [e for e in self.handle.listTimers()]
- timer_list_2 = []
- idx = 0
- for e in self.handle.listTimers():
- timer_list_2.append(e)
- # check multiple iterator on the same grouping key works
- assert timer_list_2[idx] == timer_list_1[idx]
- idx += 1
-
- if len(timer_list_1) > 0:
- assert len(timer_list_1) == 2
- self.count_state.clear()
- self.handle.deleteTimer(expiredTimerInfo.getExpiryTimeInMs())
- yield pd.DataFrame(
- {
- "id": key,
- "countAsString": str("-1"),
- "timeValues": str(expiredTimerInfo.getExpiryTimeInMs()),
- }
- )
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- if not self.count_state.exists():
- count = 0
- else:
- count = int(self.count_state.get()[0])
-
- if key == ("0",):
- self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 1)
-
- rows_count = 0
- for pdf in rows:
- pdf_count = len(pdf)
- rows_count += pdf_count
-
- count = count + rows_count
-
- self.count_state.update((str(count),))
- timestamp = str(timerValues.getCurrentProcessingTimeInMs())
-
- yield pd.DataFrame({"id": key, "countAsString": str(count), "timeValues": timestamp})
-
- def close(self) -> None:
- pass
-
-
-class SimpleStatefulProcessor(StatefulProcessor, unittest.TestCase):
- dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}}
- batch_id = 0
-
- def init(self, handle: StatefulProcessorHandle) -> None:
- # Test both string type and struct type schemas
- self.num_violations_state = handle.getValueState("numViolations", "value int")
- state_schema = StructType([StructField("value", IntegerType(), True)])
- self.temp_state = handle.getValueState("tempState", state_schema)
- handle.deleteIfExists("tempState")
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"):
- self.temp_state.exists()
- new_violations = 0
- count = 0
- key_str = key[0]
- exists = self.num_violations_state.exists()
- if exists:
- existing_violations_row = self.num_violations_state.get()
- existing_violations = existing_violations_row[0]
- assert existing_violations == self.dict[0][key_str]
- self.batch_id = 1
- else:
- existing_violations = 0
- for pdf in rows:
- pdf_count = pdf.count()
- count += pdf_count.get("temperature")
- violations_pdf = pdf.loc[pdf["temperature"] > 100]
- new_violations += violations_pdf.count().get("temperature")
- updated_violations = new_violations + existing_violations
- assert updated_violations == self.dict[self.batch_id][key_str]
- self.num_violations_state.update((updated_violations,))
- yield pd.DataFrame({"id": key, "countAsString": str(count)})
-
- def close(self) -> None:
- pass
-
-
-class StatefulProcessorChainingOps(StatefulProcessor):
- def init(self, handle: StatefulProcessorHandle) -> None:
- pass
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- for pdf in rows:
- timestamp_list = pdf["eventTime"].tolist()
- yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]})
-
- def close(self) -> None:
- pass
-
-
-# A stateful processor that inherit all behavior of SimpleStatefulProcessor except that it use
-# ttl state with a large timeout.
-class SimpleTTLStatefulProcessor(SimpleStatefulProcessor, unittest.TestCase):
- def init(self, handle: StatefulProcessorHandle) -> None:
- state_schema = StructType([StructField("value", IntegerType(), True)])
- self.num_violations_state = handle.getValueState("numViolations", state_schema, 30000)
- self.temp_state = handle.getValueState("tempState", state_schema)
- handle.deleteIfExists("tempState")
-
-
-class TTLStatefulProcessor(StatefulProcessor):
- def init(self, handle: StatefulProcessorHandle) -> None:
- state_schema = StructType([StructField("value", IntegerType(), True)])
- user_key_schema = StructType([StructField("id", StringType(), True)])
- self.ttl_count_state = handle.getValueState("ttl-state", state_schema, 10000)
- self.count_state = handle.getValueState("state", state_schema)
- self.ttl_list_state = handle.getListState("ttl-list-state", state_schema, 10000)
- self.ttl_map_state = handle.getMapState(
- "ttl-map-state", user_key_schema, state_schema, 10000
- )
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- count = 0
- ttl_count = 0
- ttl_list_state_count = 0
- ttl_map_state_count = 0
- id = key[0]
- if self.count_state.exists():
- count = self.count_state.get()[0]
- if self.ttl_count_state.exists():
- ttl_count = self.ttl_count_state.get()[0]
- if self.ttl_list_state.exists():
- iter = self.ttl_list_state.get()
- for s in iter:
- ttl_list_state_count += s[0]
- if self.ttl_map_state.exists():
- ttl_map_state_count = self.ttl_map_state.getValue(key)[0]
- for pdf in rows:
- pdf_count = pdf.count().get("temperature")
- count += pdf_count
- ttl_count += pdf_count
- ttl_list_state_count += pdf_count
- ttl_map_state_count += pdf_count
-
- self.count_state.update((count,))
- # skip updating state for the 2nd batch so that ttl state expire
- if not (ttl_count == 2 and id == "0"):
- self.ttl_count_state.update((ttl_count,))
- self.ttl_list_state.put([(ttl_list_state_count,), (ttl_list_state_count,)])
- self.ttl_map_state.updateValue(key, (ttl_map_state_count,))
- yield pd.DataFrame(
- {
- "id": [
- f"ttl-count-{id}",
- f"count-{id}",
- f"ttl-list-state-count-{id}",
- f"ttl-map-state-count-{id}",
- ],
- "count": [ttl_count, count, ttl_list_state_count, ttl_map_state_count],
- }
- )
-
- def close(self) -> None:
- pass
-
-
-class InvalidSimpleStatefulProcessor(StatefulProcessor):
- def init(self, handle: StatefulProcessorHandle) -> None:
- state_schema = StructType([StructField("value", IntegerType(), True)])
- self.num_violations_state = handle.getValueState("numViolations", state_schema)
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- count = 0
- exists = self.num_violations_state.exists()
- assert not exists
- # try to get a state variable with no value
- assert self.num_violations_state.get() is None
- self.num_violations_state.clear()
- yield pd.DataFrame({"id": key, "countAsString": str(count)})
-
- def close(self) -> None:
- pass
-
-
-class ListStateProcessor(StatefulProcessor):
- # Dict to store the expected results. The key represents the grouping key string, and the value
- # is a dictionary of pandas dataframe index -> expected temperature value. Since we set
- # maxRecordsPerBatch to 2, we expect the pandas dataframe dictionary to have 2 entries.
- dict = {0: 120, 1: 20}
-
- def init(self, handle: StatefulProcessorHandle) -> None:
- state_schema = StructType([StructField("temperature", IntegerType(), True)])
- self.list_state1 = handle.getListState("listState1", state_schema)
- self.list_state2 = handle.getListState("listState2", state_schema)
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- count = 0
- for pdf in rows:
- list_state_rows = [(120,), (20,)]
- self.list_state1.put(list_state_rows)
- self.list_state2.put(list_state_rows)
- self.list_state1.appendValue((111,))
- self.list_state2.appendValue((222,))
- self.list_state1.appendList(list_state_rows)
- self.list_state2.appendList(list_state_rows)
- pdf_count = pdf.count()
- count += pdf_count.get("temperature")
- iter1 = self.list_state1.get()
- iter2 = self.list_state2.get()
- # Mixing the iterator to test it we can resume from the correct point
- assert next(iter1)[0] == self.dict[0]
- assert next(iter2)[0] == self.dict[0]
- assert next(iter1)[0] == self.dict[1]
- assert next(iter2)[0] == self.dict[1]
- # Get another iterator for list_state1 to test if the 2 iterators (iter1 and iter3) don't
- # interfere with each other.
- iter3 = self.list_state1.get()
- assert next(iter3)[0] == self.dict[0]
- assert next(iter3)[0] == self.dict[1]
- # the second arrow batch should contain the appended value 111 for list_state1 and
- # 222 for list_state2
- assert next(iter1)[0] == 111
- assert next(iter2)[0] == 222
- assert next(iter3)[0] == 111
- # since we put another 2 rows after 111/222, check them here
- assert next(iter1)[0] == self.dict[0]
- assert next(iter2)[0] == self.dict[0]
- assert next(iter3)[0] == self.dict[0]
- assert next(iter1)[0] == self.dict[1]
- assert next(iter2)[0] == self.dict[1]
- assert next(iter3)[0] == self.dict[1]
- yield pd.DataFrame({"id": key, "countAsString": str(count)})
-
- def close(self) -> None:
- pass
-
-
-class ListStateLargeTTLProcessor(ListStateProcessor):
- def init(self, handle: StatefulProcessorHandle) -> None:
- state_schema = StructType([StructField("temperature", IntegerType(), True)])
- self.list_state1 = handle.getListState("listState1", state_schema, 30000)
- self.list_state2 = handle.getListState("listState2", state_schema, 30000)
-
-
-class MapStateProcessor(StatefulProcessor):
- def init(self, handle: StatefulProcessorHandle):
- # Test string type schemas
- self.map_state = handle.getMapState("mapState", "name string", "count int")
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- count = 0
- key1 = ("key1",)
- key2 = ("key2",)
- for pdf in rows:
- pdf_count = pdf.count()
- count += pdf_count.get("temperature")
- value1 = count
- value2 = count
- if self.map_state.exists():
- if self.map_state.containsKey(key1):
- value1 += self.map_state.getValue(key1)[0]
- if self.map_state.containsKey(key2):
- value2 += self.map_state.getValue(key2)[0]
- self.map_state.updateValue(key1, (value1,))
- self.map_state.updateValue(key2, (value2,))
- key_iter = self.map_state.keys()
- assert next(key_iter)[0] == "key1"
- assert next(key_iter)[0] == "key2"
- value_iter = self.map_state.values()
- assert next(value_iter)[0] == value1
- assert next(value_iter)[0] == value2
- map_iter = self.map_state.iterator()
- assert next(map_iter)[0] == key1
- assert next(map_iter)[1] == (value2,)
- self.map_state.removeKey(key1)
- assert not self.map_state.containsKey(key1)
- yield pd.DataFrame({"id": key, "countAsString": str(count)})
-
- def close(self) -> None:
- pass
-
-
-# A stateful processor that inherit all behavior of MapStateProcessor except that it use
-# ttl state with a large timeout.
-class MapStateLargeTTLProcessor(MapStateProcessor):
- def init(self, handle: StatefulProcessorHandle) -> None:
- key_schema = StructType([StructField("name", StringType(), True)])
- value_schema = StructType([StructField("count", IntegerType(), True)])
- self.map_state = handle.getMapState("mapState", key_schema, value_schema, 30000)
- self.list_state = handle.getListState("listState", key_schema)
-
-
-class BasicProcessor(StatefulProcessor):
- # Schema definitions
- state_schema = StructType(
- [StructField("id", IntegerType(), True), StructField("name", StringType(), True)]
- )
-
- def init(self, handle):
- self.state = handle.getValueState("state", self.state_schema)
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- for pdf in rows:
- pass
- id_val = int(key[0])
- name = f"name-{id_val}"
- self.state.update((id_val, name))
- yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})
-
- def close(self) -> None:
- pass
-
-
-class BasicProcessorNotNullable(StatefulProcessor):
- # Schema definitions
- state_schema = StructType(
- [StructField("id", IntegerType(), False), StructField("name", StringType(), False)]
- )
-
- def init(self, handle):
- self.state = handle.getValueState("state", self.state_schema)
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- for pdf in rows:
- pass
- id_val = int(key[0])
- name = f"name-{id_val}"
- self.state.update((id_val, name))
- yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})
-
- def close(self) -> None:
- pass
-
-
-class AddFieldsProcessor(StatefulProcessor):
- state_schema = StructType(
- [
- StructField("id", IntegerType(), True),
- StructField("name", StringType(), True),
- StructField("count", IntegerType(), True),
- StructField("active", BooleanType(), True),
- StructField("score", FloatType(), True),
- ]
- )
-
- def init(self, handle):
- self.state = handle.getValueState("state", self.state_schema)
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- for pdf in rows:
- pass
- id_val = int(key[0])
- name = f"name-{id_val}"
-
- if self.state.exists():
- state_data = self.state.get()
- state_dict = {
- "id": state_data[0],
- "name": state_data[1],
- "count": state_data[2],
- "active": state_data[3],
- "score": state_data[4],
- }
- else:
- state_dict = {
- "id": id_val,
- "name": name,
- "count": 100,
- "active": True,
- "score": 99.9,
- }
+@unittest.skipIf(
+ not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0",
+ cast(str, pyarrow_requirement_message or "Not supported in no-GIL mode"),
+)
+class TransformWithStateInPySparkTestsMixin(TransformWithStateTestsMixin):
+ @classmethod
+ def use_pandas(cls) -> bool:
+ return False
- self.state.update(
- (
- state_dict["id"],
- state_dict["name"] + "0",
- state_dict["count"],
- state_dict["active"],
- state_dict["score"],
- )
+ @classmethod
+ def conf(cls):
+ cfg = SparkConf()
+ cfg.set("spark.sql.shuffle.partitions", "5")
+ cfg.set(
+ "spark.sql.streaming.stateStore.providerClass",
+ "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
)
- yield pd.DataFrame({"id": [key[0]], "value": [state_dict]})
-
- def close(self) -> None:
- pass
-
+ cfg.set("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch", "2")
+ cfg.set("spark.sql.session.timeZone", "UTC")
+ # TODO SPARK-49046 this config is to stop query from FEB sink gracefully
+ cfg.set("spark.sql.streaming.noDataMicroBatches.enabled", "false")
+ return cfg
-class RemoveFieldsProcessor(StatefulProcessor):
- # Schema definitions
- state_schema = StructType(
- [StructField("id", IntegerType(), True), StructField("name", StringType(), True)]
- )
- def init(self, handle):
- self.state = handle.getValueState("state", self.state_schema)
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- for pdf in rows:
- pass
- id_val = int(key[0])
- name = f"name-{id_val}"
- if self.state.exists():
- name = self.state.get()[1]
- self.state.update((id_val, name))
- yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})
-
- def close(self) -> None:
- pass
-
-
-class ReorderedFieldsProcessor(StatefulProcessor):
- state_schema = StructType(
- [
- StructField("name", StringType(), True),
- StructField("id", IntegerType(), True),
- StructField("score", FloatType(), True),
- StructField("count", IntegerType(), True),
- StructField("active", BooleanType(), True),
- ]
- )
+@unittest.skipIf(
+ not have_pandas or not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0",
+ cast(
+ str,
+ pandas_requirement_message or pyarrow_requirement_message or "Not supported in no-GIL mode",
+ ),
+)
+class TransformWithStateInPandasTestsMixin(TransformWithStateTestsMixin):
+ @classmethod
+ def use_pandas(cls) -> bool:
+ return True
- def init(self, handle):
- self.state = handle.getValueState("state", self.state_schema)
-
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- for pdf in rows:
- pass
- id_val = int(key[0])
- name = f"name-{id_val}"
-
- if self.state.exists():
- state_data = self.state.get()
- state_dict = {
- "name": state_data[0],
- "id": state_data[1],
- "score": state_data[2],
- "count": state_data[3],
- "active": state_data[4],
- }
- else:
- state_dict = {
- "name": name,
- "id": id_val,
- "score": 99.9,
- "count": 100,
- "active": True,
- }
- self.state.update(
- (
- state_dict["name"],
- state_dict["id"],
- state_dict["score"],
- state_dict["count"],
- state_dict["active"],
- )
+ @classmethod
+ def conf(cls):
+ cfg = SparkConf()
+ cfg.set("spark.sql.shuffle.partitions", "5")
+ cfg.set(
+ "spark.sql.streaming.stateStore.providerClass",
+ "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
)
- yield pd.DataFrame({"id": [key[0]], "value": [state_dict]})
-
- def close(self) -> None:
- pass
-
-
-class UpcastProcessor(StatefulProcessor):
- state_schema = StructType(
- [
- StructField("id", LongType(), True), # Upcast from Int to Long
- StructField("name", StringType(), True),
- ]
- )
-
- def init(self, handle):
- self.state = handle.getValueState("state", self.state_schema)
+ cfg.set("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch", "2")
+ cfg.set("spark.sql.session.timeZone", "UTC")
+ # TODO SPARK-49046 this config is to stop query from FEB sink gracefully
+ cfg.set("spark.sql.streaming.noDataMicroBatches.enabled", "false")
+ return cfg
- def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
- for pdf in rows:
- pass
- id_val = int(key[0])
- name = f"name-{id_val}"
- if self.state.exists():
- id_val += self.state.get()[0] + 1
- self.state.update((id_val, name))
- yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})
- def close(self) -> None:
- pass
+class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase):
+ pass
-class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase):
+class TransformWithStateInPySparkTests(TransformWithStateInPySparkTestsMixin, ReusedSQLTestCase):
pass
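To summarize the refactor in this test file: TransformWithStateTestsMixin asks each concrete mixin whether it runs in pandas mode, obtains either the pandas or the row flavor of a processor from a factory, and then routes through transformWithStateInPandas or transformWithState accordingly. The sketch below only illustrates that dispatch shape; ExampleProcessorFactory and its return values are hypothetical placeholders, not the actual classes in helper_pandas_transform_with_state.

class ExampleProcessorFactory:
    # Hypothetical factory mirroring the helper module's shape.
    def pandas(self):
        # The real helpers return a pandas-based StatefulProcessor here.
        return "pandas-flavor-processor"

    def row(self):
        # The real helpers return a row-based StatefulProcessor here.
        return "row-flavor-processor"

def pick_processor_and_api(use_pandas, factory):
    # Mirrors get_processor() plus the API dispatch used throughout the mixin:
    # pandas mixins call transformWithStateInPandas, row mixins call transformWithState.
    processor = factory.pandas() if use_pandas else factory.row()
    api_name = "transformWithStateInPandas" if use_pandas else "transformWithState"
    return processor, api_name

assert pick_processor_and_api(True, ExampleProcessorFactory())[1] == "transformWithStateInPandas"
assert pick_processor_and_api(False, ExampleProcessorFactory())[1] == "transformWithState"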
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 706b8c0a8be81..e20ab09d8d1e1 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -26,7 +26,7 @@
from contextlib import redirect_stdout
from pyspark.sql import Row, functions, DataFrame
-from pyspark.sql.functions import col, lit, count, struct
+from pyspark.sql.functions import col, lit, count, struct, date_format, to_date, array, explode
from pyspark.sql.types import (
StringType,
IntegerType,
@@ -1076,6 +1076,32 @@ def test_metadata_column(self):
[Row(0), Row(0), Row(0)],
)
+ def test_with_column_and_generator(self):
+ # SPARK-51451: Generators should be available with withColumn
+ df = self.spark.createDataFrame([("082017",)], ["dt"]).select(
+ to_date(col("dt"), "MMyyyy").alias("dt")
+ )
+ df_dt = df.withColumn("dt", date_format(col("dt"), "MM/dd/yyyy"))
+ monthArray = [lit(x) for x in range(0, 12)]
+ df_month_y = df_dt.withColumn("month_y", explode(array(monthArray)))
+
+ assertDataFrameEqual(
+ df_month_y,
+ [Row(dt="08/01/2017", month_y=i) for i in range(12)],
+ )
+
+ df_dt_month_y = df.withColumns(
+ {
+ "dt": date_format(col("dt"), "MM/dd/yyyy"),
+ "month_y": explode(array(monthArray)),
+ }
+ )
+
+ assertDataFrameEqual(
+ df_dt_month_y,
+ [Row(dt="08/01/2017", month_y=i) for i in range(12)],
+ )
+
class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
def test_query_execution_unsupported_in_classic(self):
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index b627bc793f05a..a95bdcb8e507e 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -18,7 +18,7 @@
from contextlib import redirect_stdout
import datetime
from enum import Enum
-from inspect import getmembers, isfunction
+from inspect import getmembers, isfunction, isclass
import io
from itertools import chain
import math
@@ -65,9 +65,6 @@ def test_function_parity(self):
"any", # equivalent to python ~some
"len", # equivalent to python ~length
"udaf", # used for creating UDAF's which are not supported in PySpark
- "random", # namespace conflict with python built-in module
- "uuid", # namespace conflict with python built-in module
- "chr", # namespace conflict with python built-in function
"partitioning$", # partitioning expressions for DSv2
]
@@ -90,6 +87,76 @@ def test_function_parity(self):
expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected"
)
+ def test_wildcard_import(self):
+ all_set = set(F.__all__)
+
+ # {
+ # "abs",
+ # "acos",
+ # "acosh",
+ # "add_months",
+ # "aes_decrypt",
+ # "aes_encrypt",
+ # ...,
+ # }
+ fn_set = {
+ name
+ for (name, value) in getmembers(F, isfunction)
+ if name[0] != "_" and value.__module__ != "typing"
+ }
+
+ deprecated_fn_list = [
+ "approxCountDistinct", # deprecated
+ "bitwiseNOT", # deprecated
+ "countDistinct", # deprecated
+ "chr", # name conflict with builtin function
+ "random", # name conflict with builtin function
+ "shiftLeft", # deprecated
+ "shiftRight", # deprecated
+ "shiftRightUnsigned", # deprecated
+ "sumDistinct", # deprecated
+ "toDegrees", # deprecated
+ "toRadians", # deprecated
+ "uuid", # name conflict with builtin module
+ ]
+ unregistered_fn_list = [
+ "chr", # name conflict with builtin function
+ "random", # name conflict with builtin function
+ "uuid", # name conflict with builtin module
+ ]
+ expected_fn_all_diff = set(deprecated_fn_list + unregistered_fn_list)
+ self.assertEqual(expected_fn_all_diff, fn_set - all_set)
+
+ # {
+ # "AnalyzeArgument",
+ # "AnalyzeResult",
+ # ...,
+ # "UserDefinedFunction",
+ # "UserDefinedTableFunction",
+ # }
+ clz_set = {
+ name
+ for (name, value) in getmembers(F, isclass)
+ if name[0] != "_" and value.__module__ != "typing"
+ }
+
+ expected_clz_all_diff = {
+ "ArrayType", # should be imported from pyspark.sql.types
+ "ByteType", # should be imported from pyspark.sql.types
+ "Column", # should be imported from pyspark.sql
+ "DataType", # should be imported from pyspark.sql.types
+ "NumericType", # should be imported from pyspark.sql.types
+ "ParentDataFrame", # internal class
+ "PySparkTypeError", # should be imported from pyspark.errors
+ "PySparkValueError", # should be imported from pyspark.errors
+ "StringType", # should be imported from pyspark.sql.types
+ "StructType", # should be imported from pyspark.sql.types
+ }
+ self.assertEqual(expected_clz_all_diff, clz_set - all_set)
+
+        unknown_set = all_set - (fn_set | clz_set)
+        self.assertEqual(unknown_set, set())
+
def test_explode(self):
d = [
Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 34299bdb7740c..5f654ce6cfaea 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -18,30 +18,50 @@
import platform
import tempfile
import unittest
-from typing import Callable, Union
+from datetime import datetime
+from decimal import Decimal
+from typing import Callable, Iterable, List, Union
-from pyspark.errors import PythonException, AnalysisException
+from pyspark.errors import AnalysisException, PythonException
from pyspark.sql.datasource import (
+ CaseInsensitiveDict,
DataSource,
+ DataSourceArrowWriter,
DataSourceReader,
- InputPartition,
DataSourceWriter,
- DataSourceArrowWriter,
+ EqualNullSafe,
+ EqualTo,
+ Filter,
+ GreaterThan,
+ GreaterThanOrEqual,
+ In,
+ InputPartition,
+ IsNotNull,
+ IsNull,
+ LessThan,
+ LessThanOrEqual,
+ Not,
+ StringContains,
+ StringEndsWith,
+ StringStartsWith,
WriterCommitMessage,
- CaseInsensitiveDict,
)
from pyspark.sql.functions import spark_partition_id
+from pyspark.sql.session import SparkSession
from pyspark.sql.types import Row, StructType
+from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
+ SPARK_HOME,
+ ReusedSQLTestCase,
have_pyarrow,
pyarrow_requirement_message,
)
-from pyspark.testing import assertDataFrameEqual
-from pyspark.testing.sqlutils import ReusedSQLTestCase, SPARK_HOME
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class BasePythonDataSourceTestsMixin:
+ spark: SparkSession
+
def test_basic_data_source_class(self):
class MyDataSource(DataSource):
...
@@ -246,6 +266,209 @@ def reader(self, schema) -> "DataSourceReader":
assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)
+ def test_filter_pushdown(self):
+ class TestDataSourceReader(DataSourceReader):
+ def __init__(self):
+ self.has_filter = False
+
+ def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+ assert set(filters) == {
+ IsNotNull(("x",)),
+ IsNotNull(("y",)),
+ EqualTo(("x",), 1),
+ EqualTo(("y",), 2),
+ }, filters
+ self.has_filter = True
+                # Pretend to support the x = 1 filter even though read() never applies
+                # it; only y = 2 is returned as unsupported, so Spark re-applies just
+                # the y = 2 filter after the scan.
+ yield filters[filters.index(EqualTo(("y",), 2))]
+
+ def partitions(self):
+ assert self.has_filter
+ return super().partitions()
+
+ def read(self, partition):
+ assert self.has_filter
+ yield [1, 1]
+ yield [1, 2]
+ yield [2, 2]
+
+ class TestDataSource(DataSource):
+ @classmethod
+ def name(cls):
+ return "test"
+
+ def schema(self):
+ return "x int, y int"
+
+ def reader(self, schema) -> "DataSourceReader":
+ return TestDataSourceReader()
+
+ with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
+ self.spark.dataSource.register(TestDataSource)
+ df = self.spark.read.format("test").load().filter("x = 1 and y = 2")
+            # Only the y = 2 filter is re-applied after the scan.
+ assertDataFrameEqual(df, [Row(x=1, y=2), Row(x=2, y=2)])
+
+ def test_extraneous_filter(self):
+ class TestDataSourceReader(DataSourceReader):
+ def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+ yield EqualTo(("x",), 1)
+
+ def partitions(self):
+ assert False
+
+ def read(self, partition):
+ assert False
+
+ class TestDataSource(DataSource):
+ @classmethod
+ def name(cls):
+ return "test"
+
+ def schema(self):
+ return "x int"
+
+ def reader(self, schema) -> "DataSourceReader":
+ return TestDataSourceReader()
+
+ with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
+ self.spark.dataSource.register(TestDataSource)
+ with self.assertRaisesRegex(Exception, "DATA_SOURCE_EXTRANEOUS_FILTERS"):
+ self.spark.read.format("test").load().filter("x = 1").show()
+
+ def test_filter_pushdown_error(self):
+ error_str = "dummy error"
+
+ class TestDataSourceReader(DataSourceReader):
+ def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+ raise Exception(error_str)
+
+ def read(self, partition):
+ yield [1]
+
+ class TestDataSource(DataSource):
+ def schema(self):
+ return "x int"
+
+ def reader(self, schema) -> "DataSourceReader":
+ return TestDataSourceReader()
+
+ with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
+ self.spark.dataSource.register(TestDataSource)
+ df = self.spark.read.format("TestDataSource").load().filter("x = 1 or x is null")
+ assertDataFrameEqual(df, [Row(x=1)]) # works when not pushing down filters
+ with self.assertRaisesRegex(Exception, error_str):
+ df.filter("x = 1").explain()
+
+ def test_filter_pushdown_disabled(self):
+ class TestDataSourceReader(DataSourceReader):
+ def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+ assert False
+
+ def read(self, partition):
+ assert False
+
+ class TestDataSource(DataSource):
+ def reader(self, schema) -> "DataSourceReader":
+ return TestDataSourceReader()
+
+ with self.sql_conf({"spark.sql.python.filterPushdown.enabled": False}):
+ self.spark.dataSource.register(TestDataSource)
+ df = self.spark.read.format("TestDataSource").schema("x int").load()
+ with self.assertRaisesRegex(Exception, "DATA_SOURCE_PUSHDOWN_DISABLED"):
+ df.show()
+
+ def _check_filters(self, sql_type, sql_filter, python_filters):
+ """
+ Parameters
+ ----------
+ sql_type: str
+ The SQL type of the column x.
+ sql_filter: str
+ A SQL filter using the column x.
+ python_filters: List[Filter]
+            The Python filters expected to be passed to pushFilters, not counting
+            any IsNotNull filters (these are excluded before the comparison).
+ """
+
+ class TestDataSourceReader(DataSourceReader):
+ def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+ actual = [f for f in filters if not isinstance(f, IsNotNull)]
+ expected = python_filters
+ assert actual == expected, (actual, expected)
+ return filters
+
+ def read(self, partition):
+ yield from []
+
+ class TestDataSource(DataSource):
+ def schema(self):
+ return f"x {sql_type}"
+
+ def reader(self, schema) -> "DataSourceReader":
+ return TestDataSourceReader()
+
+ with self.sql_conf({"spark.sql.python.filterPushdown.enabled": True}):
+ self.spark.dataSource.register(TestDataSource)
+ df = self.spark.read.format("TestDataSource").load().filter(sql_filter)
+ df.count()
+
+ def test_unsupported_filter(self):
+ self._check_filters(
+ "struct", "x.a = 1 and x.b = x.c", [EqualTo(("x", "a"), 1)]
+ )
+ self._check_filters("int", "x = 1 or x > 2", [])
+ self._check_filters("int", "(0 < x and x < 1) or x = 2", [])
+ self._check_filters("int", "x % 5 = 1", [])
+ self._check_filters("array", "x[0] = 1", [])
+ self._check_filters("string", "x like 'a%a%'", [])
+ self._check_filters("string", "x ilike 'a'", [])
+ self._check_filters("string", "x = 'a' collate zh", [])
+
+ def test_filter_value_type(self):
+ self._check_filters("int", "x = 1", [EqualTo(("x",), 1)])
+ self._check_filters("int", "x = null", [EqualTo(("x",), None)])
+ self._check_filters("float", "x = 3 / 2", [EqualTo(("x",), 1.5)])
+ self._check_filters("string", "x = '1'", [EqualTo(("x",), "1")])
+ self._check_filters("array", "x = array(1, 2)", [EqualTo(("x",), [1, 2])])
+ self._check_filters(
+ "struct", "x = named_struct('x', 1)", [EqualTo(("x",), {"x": 1})]
+ )
+ self._check_filters(
+ "decimal", "x in (1.1, 2.1)", [In(("x",), [Decimal(1.1), Decimal(2.1)])]
+ )
+ self._check_filters(
+ "timestamp_ntz",
+ "x = timestamp_ntz '2020-01-01 00:00:00'",
+ [EqualTo(("x",), datetime.strptime("2020-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"))],
+ )
+ self._check_filters(
+ "interval second",
+ "x = interval '2' second",
+ [], # intervals are not supported
+ )
+
+ def test_filter_type(self):
+ self._check_filters("boolean", "x", [EqualTo(("x",), True)])
+ self._check_filters("boolean", "not x", [Not(EqualTo(("x",), True))])
+ self._check_filters("int", "x is null", [IsNull(("x",))])
+ self._check_filters("int", "x <> 0", [Not(EqualTo(("x",), 0))])
+ self._check_filters("int", "x <=> 1", [EqualNullSafe(("x",), 1)])
+ self._check_filters("int", "1 < x", [GreaterThan(("x",), 1)])
+ self._check_filters("int", "1 <= x", [GreaterThanOrEqual(("x",), 1)])
+ self._check_filters("int", "x < 1", [LessThan(("x",), 1)])
+ self._check_filters("int", "x <= 1", [LessThanOrEqual(("x",), 1)])
+ self._check_filters("string", "x like 'a%'", [StringStartsWith(("x",), "a")])
+ self._check_filters("string", "x like '%a'", [StringEndsWith(("x",), "a")])
+ self._check_filters("string", "x like '%a%'", [StringContains(("x",), "a")])
+ self._check_filters(
+ "string", "x like 'a%b'", [StringStartsWith(("x",), "a"), StringEndsWith(("x",), "b")]
+ )
+ self._check_filters("int", "x in (1, 2)", [In(("x",), [1, 2])])
+
+ def test_filter_nested_column(self):
+ self._check_filters("struct", "x.y = 1", [EqualTo(("x", "y"), 1)])
+
def _get_test_json_data_source(self):
import json
import os
diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py
index 01cf3c51d7de0..eab1ad043ef33 100644
--- a/python/pyspark/sql/tests/test_serde.py
+++ b/python/pyspark/sql/tests/test_serde.py
@@ -23,7 +23,8 @@
from pyspark.sql import Row
from pyspark.sql.functions import lit
from pyspark.sql.types import StructType, StructField, DecimalType, BinaryType
-from pyspark.testing.sqlutils import ReusedSQLTestCase, UTCOffsetTimezone
+from pyspark.testing.objects import UTCOffsetTimezone
+from pyspark.testing.sqlutils import ReusedSQLTestCase
class SerdeTestsMixin:
@@ -82,9 +83,6 @@ def test_time_with_timezone(self):
day = datetime.date.today()
now = datetime.datetime.now()
ts = time.mktime(now.timetuple())
- # class in __main__ is not serializable
- from pyspark.testing.sqlutils import UTCOffsetTimezone
-
utc = UTCOffsetTimezone()
utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds
# add microseconds to utcnow (keeping year,month,day,hour,minute,second)
diff --git a/python/pyspark/sql/tests/test_subquery.py b/python/pyspark/sql/tests/test_subquery.py
index 7c63ddb69458e..7c87f4b46cc69 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -28,7 +28,7 @@ class SubqueryTestsMixin:
def df1(self):
return self.spark.createDataFrame(
[
- (1, 1.0),
+ (1, 2.0),
(1, 2.0),
(2, 1.0),
(2, 2.0),
@@ -459,6 +459,211 @@ def test_exists_subquery(self):
),
)
+ def test_in_subquery(self):
+ with self.tempView("l", "r", "t"):
+ self.df1.createOrReplaceTempView("l")
+ self.df2.createOrReplaceTempView("r")
+ self.spark.table("r").filter(
+ sf.col("c").isNotNull() & sf.col("d").isNotNull()
+ ).createOrReplaceTempView("t")
+
+ with self.subTest("IN"):
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ sf.col("l.a").isin(self.spark.table("r").select(sf.col("c")))
+ ),
+ self.spark.sql("""select * from l where l.a in (select c from r)"""),
+ )
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ sf.col("l.a").isin(
+ self.spark.table("r")
+ .where(sf.col("l.b").outer() < sf.col("r.d"))
+ .select(sf.col("c"))
+ )
+ ),
+ self.spark.sql(
+ """select * from l where l.a in (select c from r where l.b < r.d)"""
+ ),
+ )
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ sf.col("l.a").isin(self.spark.table("r").select("c"))
+ & (sf.col("l.a") > sf.lit(2))
+ & sf.col("l.b").isNotNull()
+ ),
+ self.spark.sql(
+ """
+ select * from l
+ where l.a in (select c from r) and l.a > 2 and l.b is not null
+ """
+ ),
+ )
+
+ with self.subTest("IN with struct"), self.tempView("ll", "rr"):
+ self.spark.table("l").select(
+ "*", sf.struct("a", "b").alias("sab")
+ ).createOrReplaceTempView("ll")
+ self.spark.table("r").select(
+ "*", sf.struct(sf.col("c").alias("a"), sf.col("d").alias("b")).alias("scd")
+ ).createOrReplaceTempView("rr")
+
+ for col, values in [
+ (sf.col("sab"), "sab"),
+ (sf.struct(sf.struct(sf.col("a"), sf.col("b"))), "struct(struct(a, b))"),
+ ]:
+ for df, query in [
+ (self.spark.table("rr").select(sf.col("scd")), "select scd from rr"),
+ (
+ self.spark.table("rr").select(
+ sf.struct(sf.col("c").alias("a"), sf.col("d").alias("b"))
+ ),
+ "select struct(c as a, d as b) from rr",
+ ),
+ (
+ self.spark.table("rr").select(sf.struct(sf.col("c"), sf.col("d"))),
+ "select struct(c, d) from rr",
+ ),
+ ]:
+ sql_query = f"""select a, b from ll where {values} in ({query})"""
+ with self.subTest(sql_query=sql_query):
+ assertDataFrameEqual(
+ self.spark.table("ll").where(col.isin(df)).select("a", "b"),
+ self.spark.sql(sql_query),
+ )
+
+ with self.subTest("NOT IN"):
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ ~sf.col("a").isin(self.spark.table("r").select("c"))
+ ),
+ self.spark.sql("""select * from l where a not in (select c from r)"""),
+ )
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ ~sf.col("a").isin(
+ self.spark.table("r").where(sf.col("c").isNotNull()).select(sf.col("c"))
+ )
+ ),
+ self.spark.sql(
+ """select * from l where a not in (select c from r where c is not null)"""
+ ),
+ )
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ (
+ ~sf.struct(sf.col("a"), sf.col("b")).isin(
+ self.spark.table("t").select(sf.col("c"), sf.col("d"))
+ )
+ )
+ & (sf.col("a") < sf.lit(4))
+ ),
+ self.spark.sql(
+ """select * from l where (a, b) not in (select c, d from t) and a < 4"""
+ ),
+ )
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ ~sf.struct(sf.col("a"), sf.col("b")).isin(
+ self.spark.table("r")
+ .where(sf.col("c") > sf.lit(10))
+ .select(sf.col("c"), sf.col("d"))
+ )
+ ),
+ self.spark.sql(
+ """select * from l where (a, b) not in (select c, d from r where c > 10)"""
+ ),
+ )
+
+ with self.subTest("IN within OR"):
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ sf.col("l.a").isin(self.spark.table("r").select("c"))
+ | (
+ sf.col("l.a").isin(
+ self.spark.table("r")
+ .where(sf.col("l.b").outer() < sf.col("r.d"))
+ .select(sf.col("c"))
+ )
+ )
+ ),
+ self.spark.sql(
+ """
+ select * from l
+ where l.a in (select c from r) or l.a in (select c from r where l.b < r.d)
+ """
+ ),
+ )
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ (~sf.col("a").isin(self.spark.table("r").select(sf.col("c"))))
+ | (
+ ~sf.col("a").isin(
+ self.spark.table("r")
+ .where(sf.col("c").isNotNull())
+ .select(sf.col("c"))
+ )
+ )
+ ),
+ self.spark.sql(
+ """
+ select * from l
+ where a not in (select c from r)
+ or a not in (select c from r where c is not null)
+ """
+ ),
+ )
+
+ with self.subTest("complex IN"):
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ ~sf.struct(sf.col("a"), sf.col("b")).isin(
+ self.spark.table("r").select(sf.col("c"), sf.col("d"))
+ )
+ ),
+ self.spark.sql("""select * from l where (a, b) not in (select c, d from r)"""),
+ )
+ assertDataFrameEqual(
+ self.spark.table("l").where(
+ (
+ ~sf.struct(sf.col("a"), sf.col("b")).isin(
+ self.spark.table("t").select(sf.col("c"), sf.col("d"))
+ )
+ )
+ & ((sf.col("a") + sf.col("b")).isNotNull())
+ ),
+ self.spark.sql(
+ """
+ select * from l
+ where (a, b) not in (select c, d from t) and (a + b) is not null
+ """
+ ),
+ )
+
+ with self.subTest("same column in subquery"):
+ assertDataFrameEqual(
+ self.spark.table("l")
+ .alias("l1")
+ .where(
+ sf.col("a").isin(
+ self.spark.table("l")
+ .where(sf.col("a") < sf.lit(3))
+ .groupBy(sf.col("a"))
+ .agg({})
+ )
+ )
+ .select(sf.col("a")),
+ self.spark.sql(
+ """select a from l l1 where a in (select a from l where a < 3 group by a)"""
+ ),
+ )
+
+ with self.subTest("col IN (NULL)"):
+ assertDataFrameEqual(
+ self.spark.table("l").where(sf.col("a").isin(None)),
+ self.spark.sql("""SELECT * FROM l WHERE a IN (NULL)"""),
+ )
+
def test_scalar_subquery_with_missing_outer_reference(self):
with self.tempView("l", "r"):
self.df1.createOrReplaceTempView("l")
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 9577fe3598571..15247b97664d5 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -71,14 +71,14 @@
_make_type_verifier,
_merge_type,
)
-from pyspark.testing.sqlutils import (
- ReusedSQLTestCase,
+from pyspark.testing.objects import (
ExamplePointUDT,
PythonOnlyUDT,
ExamplePoint,
PythonOnlyPoint,
MyObject,
)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import PySparkErrorTestUtils
@@ -1077,7 +1077,7 @@ def test_udf_with_udt(self):
udf = F.udf(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
arrow_udf = F.udf(lambda p: p.y, DoubleType(), useArrow=True)
- self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+ self.assertEqual(2.0, df.select(arrow_udf(df.point)).first()[0])
udf2 = F.udf(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index efb6ff159ee54..7b7694072fd46 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -45,12 +45,13 @@
VariantVal,
)
from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
+from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
test_compiled,
test_not_compiled_message,
)
-from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.testing.utils import assertDataFrameEqual, timeout
class BaseUDFTestsMixin(object):
@@ -1259,6 +1260,115 @@ def check_err_udf_init(self):
messageParameters={"arg_name": "evalType", "arg_type": "str"},
)
+ def test_timeout_util_with_udf(self):
+ @udf
+ def f(x):
+ time.sleep(10)
+ return str(x)
+
+ @timeout(1)
+ def timeout_func():
+ self.spark.range(1).select(f("id")).show()
+
+        # Raises a py4j.protocol.Py4JNetworkError in PySpark Classic
+        # and a TimeoutError in PySpark Connect.
+ with self.assertRaises(Exception):
+ timeout_func()
+
+ def test_udf_with_udt(self):
+ row = Row(
+ label=1.0,
+ point=ExamplePoint(1.0, 2.0),
+ points=[ExamplePoint(4.0, 5.0), ExamplePoint(6.0, 7.0)],
+ )
+ df = self.spark.createDataFrame([row])
+
+ @udf(returnType=ExamplePointUDT())
+ def doubleInUDTOut(d):
+ return ExamplePoint(d, 10 * d)
+
+ @udf(returnType=DoubleType())
+ def udtInDoubleOut(e):
+ return e.y
+
+ @udf(returnType=ArrayType(ExamplePointUDT()))
+ def doubleInUDTArrayOut(d):
+ return [ExamplePoint(d + i, 10 * d + i) for i in range(2)]
+
+ @udf(returnType=DoubleType())
+ def udtArrayInDoubleOut(es):
+ return es[-1].y
+
+ @udf(returnType=ExamplePointUDT())
+ def udtInUDTOut(e):
+ return ExamplePoint(e.x * 10.0, e.y * 10.0)
+
+ @udf(returnType=DoubleType())
+ def doubleInDoubleOut(d):
+ return d * 100.0
+
+ queries = [
+ (
+ "double -> UDT",
+ df.select(doubleInUDTOut(df.label)),
+ [Row(ExamplePoint(1.0, 10.0))],
+ ),
+ (
+ "UDT -> double",
+ df.select(udtInDoubleOut(df.point)),
+ [Row(2.0)],
+ ),
+ (
+ "double -> array of UDT",
+ df.select(doubleInUDTArrayOut(df.label)),
+ [Row([ExamplePoint(1.0, 10.0), ExamplePoint(2.0, 11.0)])],
+ ),
+ (
+ "array of UDT -> double",
+ df.select(udtArrayInDoubleOut(df.points)),
+ [Row(7.0)],
+ ),
+ (
+ "double -> UDT -> double",
+ df.select(udtInDoubleOut(doubleInUDTOut(df.label))),
+ [Row(10.0)],
+ ),
+ (
+ "double -> UDT -> UDT",
+ df.select(udtInUDTOut(doubleInUDTOut(df.label))),
+ [Row(ExamplePoint(10.0, 100.0))],
+ ),
+ (
+ "double -> double -> UDT",
+ df.select(doubleInUDTOut(doubleInDoubleOut(df.label))),
+ [Row(ExamplePoint(100.0, 1000.0))],
+ ),
+ (
+ "UDT -> UDT -> double",
+ df.select(udtInDoubleOut(udtInUDTOut(df.point))),
+ [Row(20.0)],
+ ),
+ (
+ "UDT -> UDT -> UDT",
+ df.select(udtInUDTOut(udtInUDTOut(df.point))),
+ [Row(ExamplePoint(100.0, 200.0))],
+ ),
+ (
+ "UDT -> double -> double",
+ df.select(doubleInDoubleOut(udtInDoubleOut(df.point))),
+ [Row(200.0)],
+ ),
+ (
+ "UDT -> double -> UDT",
+ df.select(doubleInUDTOut(udtInDoubleOut(df.point))),
+ [Row(ExamplePoint(2.0, 20.0))],
+ ),
+ ]
+
+ for chain, actual, expected in queries:
+ with self.subTest(chain=chain):
+ assertDataFrameEqual(actual=actual, expected=expected)
+
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index bafd406d58bc2..9ec977c2552eb 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -1802,6 +1802,38 @@ def test_assert_data_frame_equal_not_support_streaming(self):
self.assertTrue(exception_thrown)
+ def test_assert_schema_equal_with_decimal_types(self):
+ """Test assertSchemaEqual with decimal types of different precision and scale
+ (SPARK-51062)."""
+ from pyspark.sql.types import StructType, StructField, DecimalType
+
+ # Same precision and scale - should pass
+ s1 = StructType(
+ [
+ StructField("price", DecimalType(10, 2), True),
+ ]
+ )
+
+ s1_copy = StructType(
+ [
+ StructField("price", DecimalType(10, 2), True),
+ ]
+ )
+
+ # This should pass
+ assertSchemaEqual(s1, s1_copy)
+
+ # Different precision and scale - should fail
+ s2 = StructType(
+ [
+ StructField("price", DecimalType(12, 4), True),
+ ]
+ )
+
+ # This should fail
+ with self.assertRaises(PySparkAssertionError):
+ assertSchemaEqual(s1, s2)
+
class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
pass
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 79fbf46f005d4..abfd0898fd545 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -117,6 +117,7 @@ def _create_py_udf(
# Note: The values inside the table are generated by `repr`. X' means it throws an exception
# during the conversion.
is_arrow_enabled = False
+
if useArrow is None:
from pyspark.sql import SparkSession
@@ -136,7 +137,7 @@ def _create_py_udf(
except ImportError:
is_arrow_enabled = False
warnings.warn(
- "Arrow optimization failed to enable because PyArrow/pandas is not installed. "
+                "Arrow optimization failed to enable because PyArrow or pandas is not installed. "
"Falling back to a non-Arrow-optimized UDF.",
RuntimeWarning,
)
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 63beda40dc52d..b0782d04cba3d 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -63,15 +63,6 @@
from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
-has_arrow: bool = False
-try:
- import pyarrow # noqa: F401
-
- has_arrow = True
-except ImportError:
- pass
-
-
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 9247fde78004f..1c926f4980a59 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -273,9 +273,11 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
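+    # conn_info is either a socket path (when PYTHON_WORKER_FACTORY_SOCK_PATH is set)
+    # or a TCP port number read from PYTHON_WORKER_FACTORY_PORT.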
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
# TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
write_int(os.getpid(), sock_file)
sock_file.flush()
diff --git a/python/pyspark/sql/worker/commit_data_source_write.py b/python/pyspark/sql/worker/commit_data_source_write.py
index c891d9f083cb8..d08d65974dfb8 100644
--- a/python/pyspark/sql/worker/commit_data_source_write.py
+++ b/python/pyspark/sql/worker/commit_data_source_write.py
@@ -119,9 +119,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index 33957616c4834..424f070127232 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -184,9 +184,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/data_source_pushdown_filters.py b/python/pyspark/sql/worker/data_source_pushdown_filters.py
new file mode 100644
index 0000000000000..0415f450fe0fc
--- /dev/null
+++ b/python/pyspark/sql/worker/data_source_pushdown_filters.py
@@ -0,0 +1,277 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import base64
+import faulthandler
+import json
+import os
+import sys
+import typing
+from dataclasses import dataclass, field
+from typing import IO, Type, Union
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.errors import PySparkAssertionError, PySparkValueError
+from pyspark.errors.exceptions.base import PySparkNotImplementedError
+from pyspark.serializers import SpecialLengths, UTF8Deserializer, read_int, write_int
+from pyspark.sql.datasource import (
+ DataSource,
+ DataSourceReader,
+ EqualNullSafe,
+ EqualTo,
+ Filter,
+ GreaterThan,
+ GreaterThanOrEqual,
+ In,
+ IsNotNull,
+ IsNull,
+ LessThan,
+ LessThanOrEqual,
+ Not,
+ StringContains,
+ StringEndsWith,
+ StringStartsWith,
+)
+from pyspark.sql.types import StructType, VariantVal, _parse_datatype_json_string
+from pyspark.sql.worker.plan_data_source_read import write_read_func_and_partitions
+from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.worker_util import (
+ check_python_version,
+ pickleSer,
+ read_command,
+ send_accumulator_updates,
+ setup_broadcasts,
+ setup_memory_limits,
+ setup_spark_files,
+)
+
+utf8_deserializer = UTF8Deserializer()
+
+BinaryFilter = Union[
+ EqualTo,
+ EqualNullSafe,
+ GreaterThan,
+ GreaterThanOrEqual,
+ LessThan,
+ LessThanOrEqual,
+ In,
+ StringStartsWith,
+ StringEndsWith,
+ StringContains,
+]
+
+binary_filters = {cls.__name__: cls for cls in typing.get_args(BinaryFilter)}
+
+UnaryFilter = Union[IsNotNull, IsNull]
+
+unary_filters = {cls.__name__: cls for cls in typing.get_args(UnaryFilter)}
+
+
+@dataclass(frozen=True)
+class FilterRef:
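+    """
+    Wraps a Filter and compares by object identity instead of value equality, so the
+    filters returned from ``pushFilters`` can be matched back to the exact instances
+    that were passed in (equal-looking copies count as extraneous).
+    """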
+ filter: Filter = field(compare=False)
+ id: int = field(init=False) # only id is used for comparison
+
+ def __post_init__(self) -> None:
+ object.__setattr__(self, "id", id(self.filter))
+
+
+def deserializeVariant(variantDict: dict) -> VariantVal:
+ value = base64.b64decode(variantDict["value"])
+ metadata = base64.b64decode(variantDict["metadata"])
+ return VariantVal(value, metadata)
+
+
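+# Each filter arrives from the JVM as a JSON dict. Judging by the fields read below, an
+# EqualTo filter on column x looks roughly like this (illustrative example only):
+#   {"name": "EqualTo", "columnPath": ["x"],
+#    "value": {"value": "<base64>", "metadata": "<base64>"}, "isNegated": false}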
+def deserializeFilter(jsonDict: dict) -> Filter:
+ name = jsonDict["name"]
+ filter: Filter
+ if name in binary_filters:
+ binary_filter_cls: Type[BinaryFilter] = binary_filters[name]
+ filter = binary_filter_cls(
+ attribute=tuple(jsonDict["columnPath"]),
+ value=deserializeVariant(jsonDict["value"]).toPython(),
+ )
+ elif name in unary_filters:
+ unary_filter_cls: Type[UnaryFilter] = unary_filters[name]
+ filter = unary_filter_cls(attribute=tuple(jsonDict["columnPath"]))
+ else:
+ raise PySparkNotImplementedError(
+ errorClass="UNSUPPORTED_FILTER",
+ messageParameters={"name": name},
+ )
+ if jsonDict["isNegated"]:
+ filter = Not(filter)
+ return filter
+
+
+def main(infile: IO, outfile: IO) -> None:
+ """
+ Main method for planning a data source read with filter pushdown.
+
+    This process is invoked from the `UserDefinedPythonDataSourceReadRunner.runInPython`
+    method in the optimizer rule `PlanPythonDataSourceScan` in the JVM. It is responsible
+    for creating a `DataSourceReader` instance, applying filter pushdown, and sending the
+    required information back to the JVM.
+
+ The infile and outfile are connected to the JVM via a socket. The JVM sends the following
+ information to this process via the socket:
+ - a `DataSource` instance representing the data source
+ - a `StructType` instance representing the output schema of the data source
+ - a list of filters to be pushed down
+ - configuration values
+
+ This process then creates a `DataSourceReader` instance by calling the `reader` method
+ on the `DataSource` instance. It applies the filters by calling the `pushFilters` method
+ on the reader and determines which filters are supported. The indices of the supported
+ filters are sent back to the JVM, along with the list of partitions and the read function.
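+
+    Example (illustrative only; ``MyReader`` is a hypothetical user-defined reader)::
+
+        class MyReader(DataSourceReader):
+            def pushFilters(self, filters):
+                for f in filters:
+                    if isinstance(f, EqualTo):
+                        ...  # remember f and apply it inside read()
+                    else:
+                        yield f  # unsupported: Spark re-applies it after the scan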
+ """
+ faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
+ try:
+ if faulthandler_log_path:
+ faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+ faulthandler_log_file = open(faulthandler_log_path, "w")
+ faulthandler.enable(file=faulthandler_log_file)
+
+ check_python_version(infile)
+
+ memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
+ setup_memory_limits(memory_limit_mb)
+
+ setup_spark_files(infile)
+ setup_broadcasts(infile)
+
+ _accumulatorRegistry.clear()
+
+ # ----------------------------------------------------------------------
+ # Start of worker logic
+ # ----------------------------------------------------------------------
+
+ # Receive the data source instance.
+ data_source = read_command(pickleSer, infile)
+ if not isinstance(data_source, DataSource):
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "a Python data source instance of type 'DataSource'",
+ "actual": f"'{type(data_source).__name__}'",
+ },
+ )
+
+ # Receive the data source output schema.
+ schema_json = utf8_deserializer.loads(infile)
+ schema = _parse_datatype_json_string(schema_json)
+ if not isinstance(schema, StructType):
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "an output schema of type 'StructType'",
+ "actual": f"'{type(schema).__name__}'",
+ },
+ )
+
+ # Get the reader.
+ reader = data_source.reader(schema=schema)
+ # Validate the reader.
+ if not isinstance(reader, DataSourceReader):
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "an instance of DataSourceReader",
+ "actual": f"'{type(reader).__name__}'",
+ },
+ )
+
+ # Receive the pushdown filters.
+ json_str = utf8_deserializer.loads(infile)
+ filter_dicts = json.loads(json_str)
+ filters = [FilterRef(deserializeFilter(f)) for f in filter_dicts]
+
+ # Push down the filters and get the indices of the unsupported filters.
+ unsupported_filters = set(
+ FilterRef(f) for f in reader.pushFilters([ref.filter for ref in filters])
+ )
+ supported_filter_indices = []
+ for i, filter in enumerate(filters):
+ if filter in unsupported_filters:
+ unsupported_filters.remove(filter)
+ else:
+ supported_filter_indices.append(i)
+
+        # If pushFilters returned any filters that were not among the original filters,
+        # raise an error.
+ if len(unsupported_filters) > 0:
+ raise PySparkValueError(
+ errorClass="DATA_SOURCE_EXTRANEOUS_FILTERS",
+ messageParameters={
+ "type": type(reader).__name__,
+ "input": str(list(filters)),
+ "extraneous": str(list(unsupported_filters)),
+ },
+ )
+
+ # Receive the max arrow batch size.
+ max_arrow_batch_size = read_int(infile)
+ assert max_arrow_batch_size > 0, (
+ "The maximum arrow batch size should be greater than 0, but got "
+ f"'{max_arrow_batch_size}'"
+ )
+
+ # Return the read function and partitions. Doing this in the same worker as filter pushdown
+ # helps reduce the number of Python worker calls.
+ write_read_func_and_partitions(
+ outfile,
+ reader=reader,
+ data_source=data_source,
+ schema=schema,
+ max_arrow_batch_size=max_arrow_batch_size,
+ )
+
+ # Return the supported filter indices.
+ write_int(len(supported_filter_indices), outfile)
+ for index in supported_filter_indices:
+ write_int(index, outfile)
+
+ # ----------------------------------------------------------------------
+ # End of worker logic
+ # ----------------------------------------------------------------------
+ except BaseException as e:
+ handle_worker_exception(e, outfile)
+ sys.exit(-1)
+ finally:
+ if faulthandler_log_path:
+ faulthandler.disable()
+ faulthandler_log_file.close()
+ os.remove(faulthandler_log_path)
+
+ send_accumulator_updates(outfile)
+
+ # check end of stream
+ if read_int(infile) == SpecialLengths.END_OF_STREAM:
+ write_int(SpecialLengths.END_OF_STREAM, outfile)
+ else:
+ # write a different value to tell JVM to not reuse this worker
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ sys.exit(-1)
+
+
+if __name__ == "__main__":
+ # Read information about how to connect back to the JVM from the environment.
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
+ main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/lookup_data_sources.py b/python/pyspark/sql/worker/lookup_data_sources.py
index 18737095fa9c6..af138ab689659 100644
--- a/python/pyspark/sql/worker/lookup_data_sources.py
+++ b/python/pyspark/sql/worker/lookup_data_sources.py
@@ -104,9 +104,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index 4c6fd4c0a77c3..5edc8185adcfe 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -168,6 +168,101 @@ def batched(iterator: Iterator, n: int) -> Iterator:
yield batch
+def write_read_func_and_partitions(
+ outfile: IO,
+ *,
+ reader: Union[DataSourceReader, DataSourceStreamReader],
+ data_source: DataSource,
+ schema: StructType,
+ max_arrow_batch_size: int,
+) -> None:
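+    """
+    Send the pickled data source read function and its planned partitions to the JVM.
+
+    For batch readers the partitions come from ``reader.partitions()``, falling back to
+    a single ``None`` partition when the method is not implemented; for streaming readers
+    an empty partition list is sent because partitions are planned in each microbatch
+    during query execution.
+    """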
+ is_streaming = isinstance(reader, DataSourceStreamReader)
+
+ # Create input converter.
+ converter = ArrowTableToRowsConversion._create_converter(BinaryType())
+
+ # Create output converter.
+ return_type = schema
+
+ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
+ partition_bytes = None
+
+ # Get the partition value from the input iterator.
+ for batch in iterator:
+ # There should be only one row/column in the batch.
+ assert batch.num_columns == 1 and batch.num_rows == 1, (
+ "Expected each batch to have exactly 1 column and 1 row, "
+ f"but found {batch.num_columns} columns and {batch.num_rows} rows."
+ )
+ columns = [column.to_pylist() for column in batch.columns]
+ partition_bytes = converter(columns[0][0])
+
+ assert (
+ partition_bytes is not None
+ ), "The input iterator for Python data source read function is empty."
+
+ # Deserialize the partition value.
+ partition = pickleSer.loads(partition_bytes)
+
+ assert partition is None or isinstance(partition, InputPartition), (
+ "Expected the partition value to be of type 'InputPartition', "
+ f"but found '{type(partition).__name__}'."
+ )
+
+ output_iter = reader.read(partition) # type: ignore[arg-type]
+
+ # Validate the output iterator.
+ if not isinstance(output_iter, Iterator):
+ raise PySparkRuntimeError(
+ errorClass="DATA_SOURCE_INVALID_RETURN_TYPE",
+ messageParameters={
+ "type": type(output_iter).__name__,
+ "name": data_source.name(),
+ "supported_types": "iterator",
+ },
+ )
+
+ return records_to_arrow_batches(output_iter, max_arrow_batch_size, return_type, data_source)
+
+ command = (data_source_read_func, return_type)
+ pickleSer._write_with_length(command, outfile)
+
+ if not is_streaming:
+        # The partitioning of a Python batch source read is determined before query execution.
+ try:
+ partitions = reader.partitions() # type: ignore[call-arg]
+ if not isinstance(partitions, list):
+ raise PySparkRuntimeError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "'partitions' to return a list",
+ "actual": f"'{type(partitions).__name__}'",
+ },
+ )
+ if not all(isinstance(p, InputPartition) for p in partitions):
+ partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
+ raise PySparkRuntimeError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "elements in 'partitions' to be of type 'InputPartition'",
+ "actual": partition_types,
+ },
+ )
+ if len(partitions) == 0:
+ partitions = [None] # type: ignore[list-item]
+ except NotImplementedError:
+ partitions = [None] # type: ignore[list-item]
+
+ # Return the serialized partition values.
+ write_int(len(partitions), outfile)
+ for partition in partitions:
+ pickleSer._write_with_length(partition, outfile)
+ else:
+        # Send an empty list of partitions for the stream reader because partitions are
+        # planned in each microbatch during query execution.
+ write_int(0, outfile)
+
+
def main(infile: IO, outfile: IO) -> None:
"""
Main method for planning a data source read.
@@ -250,6 +345,7 @@ def main(infile: IO, outfile: IO) -> None:
"The maximum arrow batch size should be greater than 0, but got "
f"'{max_arrow_batch_size}'"
)
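+    # Whether the JVM has filter pushdown enabled (spark.sql.python.filterPushdown.enabled).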
+ enable_pushdown = read_bool(infile)
is_streaming = read_bool(infile)
@@ -269,92 +365,28 @@ def main(infile: IO, outfile: IO) -> None:
"actual": f"'{type(reader).__name__}'",
},
)
-
- # Create input converter.
- converter = ArrowTableToRowsConversion._create_converter(BinaryType())
-
- # Create output converter.
- return_type = schema
-
- def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
- partition_bytes = None
-
- # Get the partition value from the input iterator.
- for batch in iterator:
- # There should be only one row/column in the batch.
- assert batch.num_columns == 1 and batch.num_rows == 1, (
- "Expected each batch to have exactly 1 column and 1 row, "
- f"but found {batch.num_columns} columns and {batch.num_rows} rows."
- )
- columns = [column.to_pylist() for column in batch.columns]
- partition_bytes = converter(columns[0][0])
-
- assert (
- partition_bytes is not None
- ), "The input iterator for Python data source read function is empty."
-
- # Deserialize the partition value.
- partition = pickleSer.loads(partition_bytes)
-
- assert partition is None or isinstance(partition, InputPartition), (
- "Expected the partition value to be of type 'InputPartition', "
- f"but found '{type(partition).__name__}'."
+ is_pushdown_implemented = (
+ getattr(reader.pushFilters, "__func__", None) is not DataSourceReader.pushFilters
)
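+    # The getattr comparison above treats pushFilters as implemented only when the
+    # subclass overrides the default DataSourceReader.pushFilters.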
-
- output_iter = reader.read(partition) # type: ignore[arg-type]
-
- # Validate the output iterator.
- if not isinstance(output_iter, Iterator):
- raise PySparkRuntimeError(
- errorClass="DATA_SOURCE_INVALID_RETURN_TYPE",
+ if is_pushdown_implemented and not enable_pushdown:
+ # Do not silently ignore pushFilters when pushdown is disabled.
+ # Raise an error to ask the user to enable pushdown.
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_PUSHDOWN_DISABLED",
messageParameters={
- "type": type(output_iter).__name__,
- "name": data_source.name(),
- "supported_types": "iterator",
+ "type": type(reader).__name__,
+ "conf": "spark.sql.python.filterPushdown.enabled",
},
)
- return records_to_arrow_batches(
- output_iter, max_arrow_batch_size, return_type, data_source
- )
-
- command = (data_source_read_func, return_type)
- pickleSer._write_with_length(command, outfile)
-
- if not is_streaming:
- # The partitioning of python batch source read is determined before query execution.
- try:
- partitions = reader.partitions() # type: ignore[call-arg]
- if not isinstance(partitions, list):
- raise PySparkRuntimeError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "'partitions' to return a list",
- "actual": f"'{type(partitions).__name__}'",
- },
- )
- if not all(isinstance(p, InputPartition) for p in partitions):
- partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
- raise PySparkRuntimeError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "elements in 'partitions' to be of type 'InputPartition'",
- "actual": partition_types,
- },
- )
- if len(partitions) == 0:
- partitions = [None] # type: ignore[list-item]
- except NotImplementedError:
- partitions = [None] # type: ignore[list-item]
-
- # Return the serialized partition values.
- write_int(len(partitions), outfile)
- for partition in partitions:
- pickleSer._write_with_length(partition, outfile)
- else:
- # Send an empty list of partition for stream reader because partitions are planned
- # in each microbatch during query execution.
- write_int(0, outfile)
+ # Send the read function and partitions to the JVM.
+ write_read_func_and_partitions(
+ outfile,
+ reader=reader,
+ data_source=data_source,
+ schema=schema,
+ max_arrow_batch_size=max_arrow_batch_size,
+ )
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
@@ -377,9 +409,11 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py b/python/pyspark/sql/worker/python_streaming_sink_runner.py
index c1bf5289cbf89..cf6246b544909 100644
--- a/python/pyspark/sql/worker/python_streaming_sink_runner.py
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -21,7 +21,7 @@
from typing import IO
from pyspark.accumulators import _accumulatorRegistry
-from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.errors import PySparkAssertionError
from pyspark.serializers import (
read_bool,
read_int,
@@ -96,44 +96,36 @@ def main(infile: IO, outfile: IO) -> None:
)
# Receive the `overwrite` flag.
overwrite = read_bool(infile)
- # Instantiate data source reader.
- try:
- # Create the data source writer instance.
- writer = data_source.streamWriter(schema=schema, overwrite=overwrite)
-
- # Receive the commit messages.
- num_messages = read_int(infile)
- commit_messages = []
- for _ in range(num_messages):
- message = pickleSer._read_with_length(infile)
- if message is not None and not isinstance(message, WriterCommitMessage):
- raise PySparkAssertionError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "an instance of WriterCommitMessage",
- "actual": f"'{type(message).__name__}'",
- },
- )
- commit_messages.append(message)
-
- batch_id = read_long(infile)
- abort = read_bool(infile)
-
- # Commit or abort the Python data source write.
- # Note the commit messages can be None if there are failed tasks.
- if abort:
- writer.abort(commit_messages, batch_id)
- else:
- writer.commit(commit_messages, batch_id)
- # Send a status code back to JVM.
- write_int(0, outfile)
- outfile.flush()
- except Exception as e:
- error_msg = "data source {} throw exception: {}".format(data_source.name, e)
- raise PySparkRuntimeError(
- errorClass="PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
- messageParameters={"action": "commitOrAbort", "error": error_msg},
- )
+ # Create the data source writer instance.
+ writer = data_source.streamWriter(schema=schema, overwrite=overwrite)
+ # Receive the commit messages.
+ num_messages = read_int(infile)
+
+ commit_messages = []
+ for _ in range(num_messages):
+ message = pickleSer._read_with_length(infile)
+ if message is not None and not isinstance(message, WriterCommitMessage):
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "an instance of WriterCommitMessage",
+ "actual": f"'{type(message).__name__}'",
+ },
+ )
+ commit_messages.append(message)
+
+ batch_id = read_long(infile)
+ abort = read_bool(infile)
+
+ # Commit or abort the Python data source write.
+ # Note the commit messages can be None if there are failed tasks.
+ if abort:
+ writer.abort(commit_messages, batch_id)
+ else:
+ writer.commit(commit_messages, batch_id)
+ # Send a status code back to JVM.
+ write_int(0, outfile)
+ outfile.flush()
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
@@ -156,9 +148,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 235e5c249f691..d6d055f01e543 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -255,9 +255,11 @@ def batch_to_rows() -> Iterator[Row]:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 9785664d7a15a..957f9d70687b8 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -262,8 +262,8 @@ def resources(self) -> Dict[str, "ResourceInformation"]:
def _load_from_socket(
- port: Optional[Union[str, int]],
- auth_secret: str,
+ conn_info: Optional[Union[str, int]],
+ auth_secret: Optional[str],
function: int,
all_gather_message: Optional[str] = None,
) -> List[str]:
@@ -271,7 +271,7 @@ def _load_from_socket(
Load data from a given socket, this is a blocking method thus only return when the socket
connection has been closed.
"""
- (sockfile, sock) = local_connect_and_auth(port, auth_secret)
+ (sockfile, sock) = local_connect_and_auth(conn_info, auth_secret)
# The call may block forever, so no timeout
sock.settimeout(None)
@@ -331,7 +331,7 @@ class BarrierTaskContext(TaskContext):
[1]
"""
- _port: ClassVar[Optional[Union[str, int]]] = None
+ _conn_info: ClassVar[Optional[Union[str, int]]] = None
_secret: ClassVar[Optional[str]] = None
@classmethod
@@ -368,13 +368,13 @@ def get(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext":
@classmethod
def _initialize(
- cls: Type["BarrierTaskContext"], port: Optional[Union[str, int]], secret: str
+ cls: Type["BarrierTaskContext"], conn_info: Optional[Union[str, int]], secret: Optional[str]
) -> None:
"""
Initialize :class:`BarrierTaskContext`, other methods within :class:`BarrierTaskContext`
can only be called after BarrierTaskContext is initialized.
"""
- cls._port = port
+ cls._conn_info = conn_info
cls._secret = secret
def barrier(self) -> None:
@@ -393,7 +393,7 @@ def barrier(self) -> None:
calls, in all possible code branches. Otherwise, you may get the job hanging
or a `SparkException` after timeout.
"""
- if self._port is None or self._secret is None:
+ if self._conn_info is None:
raise PySparkRuntimeError(
errorClass="CALL_BEFORE_INITIALIZE",
messageParameters={
@@ -402,7 +402,7 @@ def barrier(self) -> None:
},
)
else:
- _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)
+ _load_from_socket(self._conn_info, self._secret, BARRIER_FUNCTION)
def allGather(self, message: str = "") -> List[str]:
"""
@@ -422,7 +422,7 @@ def allGather(self, message: str = "") -> List[str]:
"""
if not isinstance(message, str):
raise TypeError("Argument `message` must be of type `str`")
- elif self._port is None or self._secret is None:
+ elif self._conn_info is None:
raise PySparkRuntimeError(
errorClass="CALL_BEFORE_INITIALIZE",
messageParameters={
@@ -431,7 +431,7 @@ def allGather(self, message: str = "") -> List[str]:
},
)
else:
- return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)
+ return _load_from_socket(self._conn_info, self._secret, ALL_GATHER_FUNCTION, message)
def getTaskInfos(self) -> List["BarrierTaskInfo"]:
"""
@@ -453,7 +453,7 @@ def getTaskInfos(self) -> List["BarrierTaskInfo"]:
>>> barrier_info.address
'...:...'
"""
- if self._port is None or self._secret is None:
+ if self._conn_info is None:
raise PySparkRuntimeError(
errorClass="CALL_BEFORE_INITIALIZE",
messageParameters={
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 423a717e8ab5e..c0d91fb8bd149 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -158,6 +158,9 @@ def conf(cls):
# Set a static token for all tests so the parallelism doesn't overwrite each
# tests' environment variables
conf.set("spark.connect.authenticate.token", "deadbeef")
+ # Make the max size of ML Cache larger, to avoid CONNECT_ML.CACHE_INVALID issues
+ # in tests.
+ conf.set("spark.connect.session.connectML.mlCache.maxSize", "1g")
return conf
@classmethod
@@ -185,6 +188,10 @@ def tearDownClass(cls):
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
cls.spark.stop()
+ def setUp(self) -> None:
+        # Force a cleanup of the ML cache before each test.
+ self.spark.client._cleanup_ml_cache()
+
def test_assert_remote_mode(self):
from pyspark.sql import is_remote
@@ -197,3 +204,55 @@ def quiet(self):
return QuietTest(self._legacy_sc)
else:
return contextlib.nullcontext()
+
+
+@unittest.skipIf(
+ not should_test_connect or is_remote_only(),
+ connect_requirement_message or "Requires JVM access",
+)
+class ReusedMixedTestCase(ReusedConnectTestCase, SQLTestUtils):
+ @classmethod
+ def setUpClass(cls):
+ super(ReusedMixedTestCase, cls).setUpClass()
+        # Disable the shared namespace so that pyspark.sql.functions, etc. point to the
+        # regular PySpark libraries.
+ os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
+
+ cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session.
+ cls.spark = PySparkSession._instantiatedSession
+ assert cls.spark is not None
+
+ @classmethod
+ def tearDownClass(cls):
+ try:
+            # Stopping Spark Connect closes the session in the JVM on the server side.
+ cls.spark = cls.connect
+ del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
+ finally:
+ super(ReusedMixedTestCase, cls).tearDownClass()
+
+ def setUp(self) -> None:
+        # Force a cleanup of the ML cache before each test.
+ self.connect.client._cleanup_ml_cache()
+
+ def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
+ from pyspark.sql.classic.dataframe import DataFrame as SDF
+ from pyspark.sql.connect.dataframe import DataFrame as CDF
+
+ assert isinstance(df1, (SDF, CDF))
+ if isinstance(df1, SDF):
+ str1 = df1._jdf.showString(n, truncate, False)
+ else:
+ str1 = df1._show_string(n, truncate, False)
+
+ assert isinstance(df2, (SDF, CDF))
+ if isinstance(df2, SDF):
+ str2 = df2._jdf.showString(n, truncate, False)
+ else:
+ str2 = df2._show_string(n, truncate, False)
+
+ self.assertEqual(str1, str2)
+
+ def test_assert_remote_mode(self):
+ # no need to test this in mixed mode
+ pass
diff --git a/python/pyspark/testing/objects.py b/python/pyspark/testing/objects.py
new file mode 100644
index 0000000000000..5b97664afbddc
--- /dev/null
+++ b/python/pyspark/testing/objects.py
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import datetime
+
+from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType
+
+
+class UTCOffsetTimezone(datetime.tzinfo):
+ """
+ Specifies timezone in UTC offset
+ """
+
+ def __init__(self, offset=0):
+ self.ZERO = datetime.timedelta(hours=offset)
+
+ def utcoffset(self, dt):
+ return self.ZERO
+
+ def dst(self, dt):
+ return self.ZERO
+
+
+class ExamplePointUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for ExamplePoint.
+ """
+
+ @classmethod
+ def sqlType(cls):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return "pyspark.sql.tests"
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.sql.test.ExamplePointUDT"
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+ """
+ An example class to demonstrate UDT in Scala, Java, and Python.
+ """
+
+ __UDT__ = ExamplePointUDT()
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __repr__(self):
+ return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+ def __str__(self):
+ return "(%s,%s)" % (self.x, self.y)
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__) and other.x == self.x and other.y == self.y
+
+
+class PythonOnlyUDT(UserDefinedType):
+ """
+    User-defined type (UDT) for PythonOnlyPoint.
+ """
+
+ @classmethod
+ def sqlType(cls):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return "__main__"
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return PythonOnlyPoint(datum[0], datum[1])
+
+ @staticmethod
+ def foo():
+ pass
+
+ @property
+ def props(self):
+ return {}
+
+
+class PythonOnlyPoint(ExamplePoint):
+ """
+ An example class to demonstrate UDT in only Python
+ """
+
+ __UDT__ = PythonOnlyUDT() # type: ignore
+
+
+class MyObject:
+ def __init__(self, key, value):
+ self.key = key
+ self.value = value
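A rough illustration of how these relocated test objects are used (the snippet itself is not part of the patch and assumes an active SparkSession named spark): ExamplePoint carries a __UDT__ attribute, so Spark serializes instances to an array of doubles and rebuilds them on read.

    # Illustrative usage of the UDT test objects; assumes `spark` is a live SparkSession.
    from pyspark.sql import Row
    from pyspark.testing.objects import ExamplePoint, ExamplePointUDT

    # The UDT maps a point to [x, y] and back.
    udt = ExamplePointUDT()
    assert udt.serialize(ExamplePoint(1.0, 2.0)) == [1.0, 2.0]
    assert udt.deserialize([1.0, 2.0]) == ExamplePoint(1.0, 2.0)

    # In a DataFrame, the column is stored with ExamplePointUDT as its data type.
    df = spark.createDataFrame([Row(point=ExamplePoint(1.0, 2.0))])
    assert isinstance(df.schema["point"].dataType, ExamplePointUDT)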
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 4151dfd90459f..98d04e7d5b1af 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -16,7 +16,6 @@
#
import glob
-import datetime
import math
import os
import shutil
@@ -24,7 +23,7 @@
from contextlib import contextmanager
from pyspark.sql import SparkSession
-from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
+from pyspark.sql.types import Row
from pyspark.testing.utils import (
ReusedPySparkTestCase,
PySparkErrorTestUtils,
@@ -75,108 +74,6 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
test_compiled = test_not_compiled_message is None
-class UTCOffsetTimezone(datetime.tzinfo):
- """
- Specifies timezone in UTC offset
- """
-
- def __init__(self, offset=0):
- self.ZERO = datetime.timedelta(hours=offset)
-
- def utcoffset(self, dt):
- return self.ZERO
-
- def dst(self, dt):
- return self.ZERO
-
-
-class ExamplePointUDT(UserDefinedType):
- """
- User-defined type (UDT) for ExamplePoint.
- """
-
- @classmethod
- def sqlType(cls):
- return ArrayType(DoubleType(), False)
-
- @classmethod
- def module(cls):
- return "pyspark.sql.tests"
-
- @classmethod
- def scalaUDT(cls):
- return "org.apache.spark.sql.test.ExamplePointUDT"
-
- def serialize(self, obj):
- return [obj.x, obj.y]
-
- def deserialize(self, datum):
- return ExamplePoint(datum[0], datum[1])
-
-
-class ExamplePoint:
- """
- An example class to demonstrate UDT in Scala, Java, and Python.
- """
-
- __UDT__ = ExamplePointUDT()
-
- def __init__(self, x, y):
- self.x = x
- self.y = y
-
- def __repr__(self):
- return "ExamplePoint(%s,%s)" % (self.x, self.y)
-
- def __str__(self):
- return "(%s,%s)" % (self.x, self.y)
-
- def __eq__(self, other):
- return isinstance(other, self.__class__) and other.x == self.x and other.y == self.y
-
-
-class PythonOnlyUDT(UserDefinedType):
- """
- User-defined type (UDT) for ExamplePoint.
- """
-
- @classmethod
- def sqlType(cls):
- return ArrayType(DoubleType(), False)
-
- @classmethod
- def module(cls):
- return "__main__"
-
- def serialize(self, obj):
- return [obj.x, obj.y]
-
- def deserialize(self, datum):
- return PythonOnlyPoint(datum[0], datum[1])
-
- @staticmethod
- def foo():
- pass
-
- @property
- def props(self):
- return {}
-
-
-class PythonOnlyPoint(ExamplePoint):
- """
- An example class to demonstrate UDT in only Python
- """
-
- __UDT__ = PythonOnlyUDT() # type: ignore
-
-
-class MyObject:
- def __init__(self, key, value):
- self.key = key
- self.value = value
-
-
class SQLTestUtils:
"""
This util assumes the instance of this to have 'spark' attribute, having a spark session.
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 780f0f90ac62a..d38957ca40262 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -23,6 +23,7 @@
import functools
from decimal import Decimal
from time import time, sleep
+import signal
from typing import (
Any,
Optional,
@@ -122,6 +123,26 @@ def write_int(i):
return struct.pack("!i", i)
+def timeout(seconds):
+ def decorator(func):
+ def handler(signum, frame):
+ raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds")
+
+ def wrapper(*args, **kwargs):
+ signal.alarm(0)
+ signal.signal(signal.SIGALRM, handler)
+ signal.alarm(seconds)
+ try:
+ result = func(*args, **kwargs)
+ finally:
+ signal.alarm(0)
+ return result
+
+ return wrapper
+
+ return decorator
+
+
def eventually(
timeout=30.0,
catch_assertions=False,
@@ -225,7 +246,11 @@ def conf(cls):
def setUpClass(cls):
from pyspark import SparkContext
- cls.sc = SparkContext("local[4]", cls.__name__, conf=cls.conf())
+ cls.sc = SparkContext(cls.master(), cls.__name__, conf=cls.conf())
+
+ @classmethod
+ def master(cls):
+ return "local[4]"
@classmethod
def tearDownClass(cls):
@@ -519,6 +544,9 @@ def compare_datatypes_ignore_nullable(dt1: Any, dt2: Any):
if dt1.typeName() == dt2.typeName():
if dt1.typeName() == "array":
return compare_datatypes_ignore_nullable(dt1.elementType, dt2.elementType)
+ elif dt1.typeName() == "decimal":
+ # Fix for SPARK-51062: Compare precision and scale for decimal types
+ return dt1.precision == dt2.precision and dt1.scale == dt2.scale
elif dt1.typeName() == "struct":
return compare_schemas_ignore_nullable(dt1, dt2)
else:
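The timeout helper added above raises TimeoutError via SIGALRM, so it only works on Unix-like systems and in the main thread. A minimal usage sketch; the function name slow_operation is illustrative.

    # Illustrative usage of pyspark.testing.utils.timeout (Unix, main thread only).
    import time

    from pyspark.testing.utils import timeout

    @timeout(2)
    def slow_operation():
        time.sleep(10)  # would blow past the 2-second budget

    try:
        slow_operation()
    except TimeoutError as e:
        print(e)  # "Function slow_operation timed out after 2 seconds"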
diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py
index 5f2c8b49d279d..909ed0447154b 100644
--- a/python/pyspark/tests/test_appsubmit.py
+++ b/python/pyspark/tests/test_appsubmit.py
@@ -36,6 +36,8 @@ def setUp(self):
"spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
"--conf",
"spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
+ "--conf",
+ "spark.python.unix.domain.socket.enabled=false",
]
def tearDown(self):
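The extra --conf above switches Python workers in this suite back to TCP sockets. For reference, a rough sketch of setting the same flag when building a session programmatically; the app name is made up, and in test suites the flag is normally passed at submit time exactly as shown above.

    # Illustrative: fall back to TCP + auth-secret sockets for Python workers.
    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder
        .appName("uds-disabled-example")  # hypothetical app name
        .config("spark.python.unix.domain.socket.enabled", "false")
        .getOrCreate()
    )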
diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py
index e1079ca7b4e89..d9bda1e569933 100644
--- a/python/pyspark/tests/test_util.py
+++ b/python/pyspark/tests/test_util.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
import os
+import time
import unittest
from unittest.mock import patch
@@ -23,7 +24,7 @@
from pyspark import keyword_only
from pyspark.util import _parse_memory
from pyspark.loose_version import LooseVersion
-from pyspark.testing.utils import PySparkTestCase, eventually
+from pyspark.testing.utils import PySparkTestCase, eventually, timeout
from pyspark.find_spark_home import _find_spark_home
@@ -87,6 +88,28 @@ def test_find_spark_home(self):
finally:
os.environ["SPARK_HOME"] = origin
+ def test_timeout_decorator(self):
+ @timeout(1)
+ def timeout_func():
+ time.sleep(10)
+
+ with self.assertRaises(TimeoutError) as e:
+ timeout_func()
+ self.assertTrue("Function timeout_func timed out after 1 seconds" in str(e.exception))
+
+ def test_timeout_function(self):
+ def timeout_func():
+ time.sleep(10)
+
+ with self.assertRaises(TimeoutError) as e:
+ timeout(1)(timeout_func)()
+ self.assertTrue("Function timeout_func timed out after 1 seconds" in str(e.exception))
+
+ def test_timeout_lambda(self):
+ with self.assertRaises(TimeoutError) as e:
+ timeout(1)(lambda: time.sleep(10))()
+        self.assertTrue("Function <lambda> timed out after 1 seconds" in str(e.exception))
+
@eventually(timeout=180, catch_assertions=True)
def test_eventually_decorator(self):
import random
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 5a5a8d31e77dc..7e07d95538e4a 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -63,6 +63,8 @@
ArrowCogroupedMapUDFType,
PandasGroupedMapUDFTransformWithStateType,
PandasGroupedMapUDFTransformWithStateInitStateType,
+ GroupedMapUDFTransformWithStateType,
+ GroupedMapUDFTransformWithStateInitStateType,
)
from pyspark.sql._typing import (
SQLArrowBatchedUDFType,
@@ -395,6 +397,12 @@ def outer(ff: Callable) -> Callable:
@functools.wraps(ff)
def inner(*args: Any, **kwargs: Any) -> Any:
+ # Propagates the active remote spark session to the current thread.
+ from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
+
+ RemoteSparkSession._set_default_and_active_session(
+ session # type: ignore[arg-type]
+ )
# Set thread locals in child thread.
for attr, value in session_client_thread_local_attrs:
setattr(
@@ -633,6 +641,10 @@ class PythonEvalType:
SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF: "PandasGroupedMapUDFTransformWithStateInitStateType" = ( # noqa: E501
212
)
+ SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF: "GroupedMapUDFTransformWithStateType" = 213
+ SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF: "GroupedMapUDFTransformWithStateInitStateType" = ( # noqa: E501
+ 214
+ )
SQL_TABLE_UDF: "SQLTableUDFType" = 300
SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301
@@ -652,9 +664,9 @@ def _create_local_socket(sock_info: "JavaArray") -> "io.BufferedRWPair":
"""
sockfile: "io.BufferedRWPair"
sock: "socket.socket"
- port: int = sock_info[0]
+ conn_info: int = sock_info[0]
auth_secret: str = sock_info[1]
- sockfile, sock = local_connect_and_auth(port, auth_secret)
+ sockfile, sock = local_connect_and_auth(conn_info, auth_secret)
# The RDD materialization time is unpredictable, if we set a timeout for socket reading
# operation, it will very possibly fail. See SPARK-18281.
sock.settimeout(None)
@@ -731,7 +743,9 @@ def __del__(self) -> None:
return iter(PyLocalIterable(sock_info, serializer))
-def local_connect_and_auth(port: Optional[Union[str, int]], auth_secret: str) -> Tuple:
+def local_connect_and_auth(
+ conn_info: Optional[Union[str, int]], auth_secret: Optional[str]
+) -> Tuple:
"""
Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
Handles IPV4 & IPV6, does some error handling.
@@ -739,26 +753,49 @@ def local_connect_and_auth(port: Optional[Union[str, int]], auth_secret: str) ->
Parameters
----------
-    port : str or int, optional
-    auth_secret : str
+    conn_info : str or int, optional
+    auth_secret : str, optional
Returns
-------
tuple
with (sockfile, sock)
"""
+ is_unix_domain_socket = isinstance(conn_info, str) and auth_secret is None
+ if is_unix_domain_socket:
+ sock_path = conn_info
+ assert isinstance(sock_path, str)
+ sock = None
+ try:
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 15)))
+ sock.connect(sock_path)
+ sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+ return (sockfile, sock)
+ except socket.error as e:
+ if sock is not None:
+ sock.close()
+ raise PySparkRuntimeError(
+ errorClass="CANNOT_OPEN_SOCKET",
+ messageParameters={
+ "errors": "tried to connect to %s, but an error occurred: %s"
+ % (sock_path, str(e)),
+ },
+ )
+
sock = None
errors = []
# Support for both IPv4 and IPv6.
addr = "127.0.0.1"
if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
addr = "::1"
- for res in socket.getaddrinfo(addr, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+ for res in socket.getaddrinfo(addr, conn_info, socket.AF_UNSPEC, socket.SOCK_STREAM):
af, socktype, proto, _, sa = res
try:
sock = socket.socket(af, socktype, proto)
sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 15)))
sock.connect(sa)
sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+ assert isinstance(auth_secret, str)
_do_server_auth(sockfile, auth_secret)
return (sockfile, sock)
except socket.error as e:
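The reworked local_connect_and_auth above picks the transport from the shape of its arguments: a string conn_info with no auth secret means a Unix domain socket path, anything else keeps the original TCP-loopback-plus-auth-handshake path. A tiny standalone sketch of that dispatch rule; connection_kind is a hypothetical stand-in, not the real helper.

    # Minimal sketch of the dispatch rule used by local_connect_and_auth (illustrative).
    def connection_kind(conn_info, auth_secret):
        if isinstance(conn_info, str) and auth_secret is None:
            return "unix-domain-socket"  # conn_info is a socket file path
        return "tcp-loopback"            # conn_info is a port; auth_secret is required

    assert connection_kind("/tmp/spark-worker.sock", None) == "unix-domain-socket"
    assert connection_kind(56789, "worker-auth-secret") == "tcp-loopback"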
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 29dfd65c0e2b8..5f4408851c671 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -34,7 +34,7 @@
_deserialize_accumulator,
)
from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPySparkFuncMode
from pyspark.taskcontext import BarrierTaskContext, TaskContext
from pyspark.resource import ResourceInformation
from pyspark.util import PythonEvalType, local_connect_and_auth
@@ -45,7 +45,6 @@
write_long,
read_int,
SpecialLengths,
- UTF8Deserializer,
CPickleSerializer,
BatchedSerializer,
)
@@ -60,6 +59,8 @@
ApplyInPandasWithStateSerializer,
TransformWithStateInPandasSerializer,
TransformWithStateInPandasInitStateSerializer,
+ TransformWithStateInPySparkRowSerializer,
+ TransformWithStateInPySparkRowInitStateSerializer,
)
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import (
@@ -141,6 +142,7 @@ def verify_result_length(result, length):
raise PySparkRuntimeError(
errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
messageParameters={
+ "udf_type": "pandas_udf",
"expected": str(length),
"actual": str(len(result)),
},
@@ -213,6 +215,7 @@ def verify_result_length(result, length):
raise PySparkRuntimeError(
errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
messageParameters={
+ "udf_type": "arrow_batch_udf",
"expected": str(length),
"actual": str(len(result)),
},
@@ -221,7 +224,7 @@ def verify_result_length(result, length):
return (
args_kwargs_offsets,
- lambda *a: (verify_result_length(evaluate(*a), len(a[0])), arrow_return_type),
+ lambda *a: (verify_result_length(evaluate(*a), len(a[0])), arrow_return_type, return_type),
)
@@ -569,6 +572,39 @@ def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
return lambda p, m, k, v: [(wrapped(p, m, k, v), arrow_return_type)]
+def wrap_grouped_transform_with_state_udf(f, return_type, runner_conf):
+ def wrapped(stateful_processor_api_client, mode, key, values):
+ result_iter = f(stateful_processor_api_client, mode, key, values)
+
+        # TODO(SPARK-XXXXX): add verification that elements in result_iter are
+        # indeed of type Row and conform to the assigned cols
+
+ return result_iter
+
+ arrow_return_type = to_arrow_type(return_type, use_large_var_types(runner_conf))
+ return lambda p, m, k, v: [(wrapped(p, m, k, v), arrow_return_type)]
+
+
+def wrap_grouped_transform_with_state_init_state_udf(f, return_type, runner_conf):
+ def wrapped(stateful_processor_api_client, mode, key, values):
+ if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+ values_gen = values[0]
+ init_states_gen = values[1]
+ else:
+ values_gen = iter([])
+ init_states_gen = iter([])
+
+ result_iter = f(stateful_processor_api_client, mode, key, values_gen, init_states_gen)
+
+        # TODO(SPARK-XXXXX): add verification that elements in result_iter are
+        # indeed of type Row and conform to the assigned cols
+
+ return result_iter
+
+ arrow_return_type = to_arrow_type(return_type, use_large_var_types(runner_conf))
+ return lambda p, m, k, v: [(wrapped(p, m, k, v), arrow_return_type)]
+
+
def wrap_grouped_map_pandas_udf_with_state(f, return_type, runner_conf):
"""
Provides a new lambda instance wrapping user function of applyInPandasWithState.
@@ -933,6 +969,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
return args_offsets, wrap_grouped_transform_with_state_pandas_init_state_udf(
func, return_type, runner_conf
)
+ elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
+ return args_offsets, wrap_grouped_transform_with_state_udf(func, return_type, runner_conf)
+ elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF:
+ return args_offsets, wrap_grouped_transform_with_state_init_state_udf(
+ func, return_type, runner_conf
+ )
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it
return args_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
@@ -1530,6 +1572,8 @@ def read_udfs(pickleSer, infile, eval_type):
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
):
# Load conf used for pandas_udf evaluation
num_conf = read_int(infile)
@@ -1544,8 +1588,12 @@ def read_udfs(pickleSer, infile, eval_type):
elif (
eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
+ or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF
+ or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF
):
state_server_port = read_int(infile)
+ if state_server_port == -1:
+ state_server_port = utf8_deserializer.loads(infile)
key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
@@ -1593,6 +1641,20 @@ def read_udfs(pickleSer, infile, eval_type):
ser = TransformWithStateInPandasInitStateSerializer(
timezone, safecheck, _assign_cols_by_name, arrow_max_records_per_batch
)
+ elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
+ arrow_max_records_per_batch = runner_conf.get(
+ "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
+ )
+ arrow_max_records_per_batch = int(arrow_max_records_per_batch)
+
+ ser = TransformWithStateInPySparkRowSerializer(arrow_max_records_per_batch)
+ elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF:
+ arrow_max_records_per_batch = runner_conf.get(
+ "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
+ )
+ arrow_max_records_per_batch = int(arrow_max_records_per_batch)
+
+ ser = TransformWithStateInPySparkRowInitStateSerializer(arrow_max_records_per_batch)
elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
ser = ArrowStreamUDFSerializer()
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
@@ -1612,6 +1674,13 @@ def read_udfs(pickleSer, infile, eval_type):
ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
# Arrow-optimized Python UDF uses explicit Arrow cast for type coercion
arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
+ # Arrow-optimized Python UDF takes input types
+ input_types = (
+ [f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))]
+ if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
+ else None
+ )
+
ser = ArrowStreamPandasUDFSerializer(
timezone,
safecheck,
@@ -1620,6 +1689,7 @@ def read_udfs(pickleSer, infile, eval_type):
struct_in_pandas,
ndarray_as_list,
arrow_cast,
+ input_types,
)
else:
batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100"))
@@ -1767,7 +1837,7 @@ def mapper(a):
def mapper(a):
mode = a[0]
- if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+ if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
key = a[1]
def values_gen():
@@ -1804,7 +1874,7 @@ def values_gen():
def mapper(a):
mode = a[0]
- if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+ if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
key = a[1]
def values_gen():
@@ -1819,6 +1889,66 @@ def values_gen():
# mode == PROCESS_TIMER or mode == COMPLETE
return f(stateful_processor_api_client, mode, None, iter([]))
+ elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
+ # We assume there is only one UDF here because grouped map doesn't
+ # support combining multiple UDFs.
+ assert num_udfs == 1
+
+ # See TransformWithStateInPySparkExec for how arg_offsets are used to
+ # distinguish between grouping attributes and data attributes
+ arg_offsets, f = read_single_udf(
+ pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
+ )
+ parsed_offsets = extract_key_value_indexes(arg_offsets)
+ ser.key_offsets = parsed_offsets[0][0]
+ stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
+
+ def mapper(a):
+ mode = a[0]
+
+ if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+ key = a[1]
+ values = a[2]
+
+ # This must be generator comprehension - do not materialize.
+ return f(stateful_processor_api_client, mode, key, values)
+ else:
+ # mode == PROCESS_TIMER or mode == COMPLETE
+ return f(stateful_processor_api_client, mode, None, iter([]))
+
+ elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF:
+ # We assume there is only one UDF here because grouped map doesn't
+ # support combining multiple UDFs.
+ assert num_udfs == 1
+
+ # See TransformWithStateInPandasExec for how arg_offsets are used to
+ # distinguish between grouping attributes and data attributes
+ arg_offsets, f = read_single_udf(
+ pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
+ )
+ # parsed offsets:
+ # [
+ # [groupingKeyOffsets, dedupDataOffsets],
+ # [initStateGroupingOffsets, dedupInitDataOffsets]
+ # ]
+ parsed_offsets = extract_key_value_indexes(arg_offsets)
+ ser.key_offsets = parsed_offsets[0][0]
+ ser.init_key_offsets = parsed_offsets[1][0]
+ stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
+
+ def mapper(a):
+ mode = a[0]
+
+ if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+ key = a[1]
+ values = a[2]
+
+ # This must be generator comprehension - do not materialize.
+ return f(stateful_processor_api_client, mode, key, values)
+ else:
+ # mode == PROCESS_TIMER or mode == COMPLETE
+ return f(stateful_processor_api_client, mode, None, iter([]))
+
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
import pyarrow as pa
@@ -1973,8 +2103,6 @@ def main(infile, outfile):
# read inputs only for a barrier task
isBarrier = read_bool(infile)
- boundPort = read_int(infile)
- secret = UTF8Deserializer().loads(infile)
memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)
@@ -1982,6 +2110,12 @@ def main(infile, outfile):
# initialize global state
taskContext = None
if isBarrier:
+ boundPort = read_int(infile)
+ secret = None
+ if boundPort == -1:
+ boundPort = utf8_deserializer.loads(infile)
+ else:
+ secret = utf8_deserializer.loads(infile)
taskContext = BarrierTaskContext._getOrCreate()
BarrierTaskContext._initialize(boundPort, secret)
# Set the task context instance here, so we can get it by TaskContext.get for
@@ -2075,9 +2209,11 @@ def process():
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH", int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
# TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
write_int(os.getpid(), sock_file)
sock_file.flush()
diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py
index 5c758d3f83fe6..c2f35db8d52de 100644
--- a/python/pyspark/worker_util.py
+++ b/python/pyspark/worker_util.py
@@ -156,9 +156,13 @@ def setup_broadcasts(infile: IO) -> None:
num_broadcast_variables = read_int(infile)
if needs_broadcast_decryption_server:
# read the decrypted data from a server in the jvm
- port = read_int(infile)
- auth_secret = utf8_deserializer.loads(infile)
- (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
+ conn_info = read_int(infile)
+ auth_secret = None
+ if conn_info == -1:
+ conn_info = utf8_deserializer.loads(infile)
+ else:
+ auth_secret = utf8_deserializer.loads(infile)
+ (broadcast_sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
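The worker.py and worker_util.py changes above share one wire convention: where the JVM previously always sent a port, it now sends -1 as a sentinel when a Unix domain socket path (a UTF-8 string) follows instead of a port/secret pair. A small self-contained sketch of that read pattern; read_conn_info is a hypothetical helper, and the real call sites differ slightly in what they read after the sentinel.

    # Illustrative reader for the "-1 means a socket path follows" convention.
    import io
    import struct

    from pyspark.serializers import UTF8Deserializer, read_int

    utf8_deserializer = UTF8Deserializer()

    def read_conn_info(infile):
        value = read_int(infile)
        if value == -1:
            # Unix domain socket: a length-prefixed UTF-8 path follows, no auth secret.
            return utf8_deserializer.loads(infile), None
        # TCP: `value` is the port and the auth secret follows.
        return value, utf8_deserializer.loads(infile)

    # Simulated stream: the -1 sentinel followed by a length-prefixed socket path.
    path = b"/tmp/spark-worker.sock"
    stream = io.BytesIO(struct.pack("!i", -1) + struct.pack("!i", len(path)) + path)
    assert read_conn_info(stream) == ("/tmp/spark-worker.sock", None)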
diff --git a/python/run-tests.py b/python/run-tests.py
index 64ac48e210db4..091fcfe73ac10 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -111,19 +111,24 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_
while os.path.isdir(tmp_dir):
tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
os.mkdir(tmp_dir)
+ sock_dir = os.getenv('TMPDIR') or os.getenv('TEMP') or os.getenv('TMP') or '/tmp'
env["TMPDIR"] = tmp_dir
metastore_dir = os.path.join(tmp_dir, str(uuid.uuid4()))
while os.path.isdir(metastore_dir):
metastore_dir = os.path.join(metastore_dir, str(uuid.uuid4()))
os.mkdir(metastore_dir)
- # Also override the JVM's temp directory by setting driver and executor options.
- java_options = "-Djava.io.tmpdir={0}".format(tmp_dir)
+ # Also override the JVM's temp directory and log4j conf by setting driver and executor options.
+ log4j2_path = os.path.join(SPARK_HOME, "python/test_support/log4j2.properties")
+ java_options = "-Djava.io.tmpdir={0} -Dlog4j.configurationFile={1}".format(
+ tmp_dir, log4j2_path
+ )
java_options = java_options + " -Xss4M"
spark_args = [
"--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
"--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
"--conf", "spark.sql.warehouse.dir='{0}'".format(metastore_dir),
+ "--conf", "spark.python.unix.domain.socket.dir={0}".format(sock_dir),
"pyspark-shell",
]
diff --git a/python/test_support/log4j2.properties b/python/test_support/log4j2.properties
new file mode 100644
index 0000000000000..6629658c1d28c
--- /dev/null
+++ b/python/test_support/log4j2.properties
@@ -0,0 +1,31 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+rootLogger.level = info
+rootLogger.appenderRef.file.ref = File
+
+appender.file.type = File
+appender.file.name = File
+appender.file.fileName = target/unit-tests.log
+appender.file.append = true
+appender.file.layout.type = PatternLayout
+appender.file.layout.pattern = %d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n%ex
+
+# Silence verbose logs from 3rd-party libraries.
+logger.netty.name = io.netty
+logger.netty.level = info
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 8969bc8b5e2b9..9bf716e52fffe 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -400,6 +400,7 @@ class ReplSuite extends SparkFunSuite {
test("register UDF via SparkSession.addArtifact") {
val artifactPath = new File("src/test/resources").toPath
val intSumUdfPath = artifactPath.resolve("IntSumUdf.class")
+ assume(intSumUdfPath.toFile.exists)
val output = runInterpreterInPasteMode("local",
s"""
|import org.apache.spark.sql.api.java.UDF2
@@ -438,6 +439,7 @@ class ReplSuite extends SparkFunSuite {
test("register a class via SparkSession.addArtifact") {
val artifactPath = new File("src/test/resources").toPath
val intSumUdfPath = artifactPath.resolve("IntSumUdf.class")
+ assume(intSumUdfPath.toFile.exists)
val output = runInterpreterInPasteMode("local",
s"""
|import org.apache.spark.sql.functions.udf
diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md
index 4b0ec2cd852fa..3511a64f7fe5b 100644
--- a/resource-managers/kubernetes/integration-tests/README.md
+++ b/resource-managers/kubernetes/integration-tests/README.md
@@ -199,9 +199,9 @@ to the wrapper scripts and using the wrapper scripts will simply set these appro
spark.kubernetes.test.javaImageTag
- A specific OpenJDK base image tag to use, when set uses it instead of azul/zulu-openjdk.
+  A specific Azul Zulu OpenJDK base image tag to use; when set, it is used instead of the default tag 21.
-
azul/zulu-openjdk
+
N/A
spark.kubernetes.test.imageTagFile
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/AmIpFilterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/AmIpFilterSuite.scala
index e25bd665dec0d..aabf35d66e6bf 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/AmIpFilterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/AmIpFilterSuite.scala
@@ -153,7 +153,7 @@ class AmIpFilterSuite extends SparkFunSuite {
// change proxy configurations
params = new util.HashMap[String, String]
- params.put(AmIpFilter.PROXY_HOSTS, "unknownhost")
+ params.put(AmIpFilter.PROXY_HOSTS, "unknownhostaf79d34c")
params.put(AmIpFilter.PROXY_URI_BASES, proxyUri)
conf = new DummyFilterConfig(params)
filter.init(conf)
@@ -162,8 +162,10 @@ class AmIpFilterSuite extends SparkFunSuite {
assert(!filter.getProxyAddresses.isEmpty)
// waiting for configuration update
eventually(timeout(5.seconds), interval(500.millis)) {
- assertThrows[ServletException] {
- filter.getProxyAddresses.isEmpty
+ try {
+ assert(filter.getProxyAddresses.isEmpty)
+ } catch {
+ case e: ServletException => // do nothing
}
}
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 4408817b0426d..b3a792bbfc739 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -40,6 +40,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
+import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.internal.config.UI._
import org.apache.spark.launcher._
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded}
@@ -268,11 +269,19 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
}
test("run Python application in yarn-client mode") {
- testPySpark(true)
+ testPySpark(
+ true,
+ // User is unknown in this suite.
+ extraConf = Map(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key -> false.toString)
+ )
}
test("run Python application in yarn-cluster mode") {
- testPySpark(false)
+ testPySpark(
+ false,
+ // User is unknown in this suite.
+ extraConf = Map(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key -> false.toString)
+ )
}
test("run Python application with Spark Connect in yarn-client mode") {
@@ -290,6 +299,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
testPySpark(
clientMode = false,
extraConf = Map(
+ PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key -> false.toString, // User is unknown in this suite.
"spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON"
-> sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", pythonExecutablePath),
"spark.yarn.appMasterEnv.PYSPARK_PYTHON"
diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh
index b7233e6e9bf3d..7dc241f097227 100755
--- a/sbin/spark-daemon.sh
+++ b/sbin/spark-daemon.sh
@@ -30,7 +30,7 @@
# SPARK_NICENESS The scheduling priority for daemons. Defaults to 0.
# SPARK_NO_DAEMONIZE If set, will run the proposed command in the foreground. It will not output a PID file.
##
-
+export SPARK_CONNECT_MODE=0
usage="Usage: spark-daemon.sh [--config ] (start|stop|submit|decommission|status) "
# if no args specified, show usage
diff --git a/sbin/start-connect-server.sh b/sbin/start-connect-server.sh
index 7f0c430a468a9..03e7a118f4590 100755
--- a/sbin/start-connect-server.sh
+++ b/sbin/start-connect-server.sh
@@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+export SPARK_CONNECT_MODE=0
# Enter posix mode for bash
set -o posix
diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh
index a457526979341..b77459f0e57af 100755
--- a/sbin/start-thriftserver.sh
+++ b/sbin/start-thriftserver.sh
@@ -19,6 +19,7 @@
#
# Shell script for starting the Spark SQL Thrift server
+export SPARK_CONNECT_MODE=0
# Enter posix mode for bash
set -o posix
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index b868eea41b692..1f82c47e6abcb 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -222,6 +222,7 @@ DROP: 'DROP';
ELSE: 'ELSE';
ELSEIF: 'ELSEIF';
END: 'END';
+ENFORCED: 'ENFORCED';
ESCAPE: 'ESCAPE';
ESCAPED: 'ESCAPED';
EVOLUTION: 'EVOLUTION';
@@ -290,6 +291,7 @@ ITEMS: 'ITEMS';
ITERATE: 'ITERATE';
JOIN: 'JOIN';
JSON: 'JSON';
+KEY: 'KEY';
KEYS: 'KEYS';
LANGUAGE: 'LANGUAGE';
LAST: 'LAST';
@@ -337,6 +339,7 @@ NOT: 'NOT';
NULL: 'NULL';
NULLS: 'NULLS';
NUMERIC: 'NUMERIC';
+NORELY: 'NORELY';
OF: 'OF';
OFFSET: 'OFFSET';
ON: 'ON';
@@ -362,6 +365,8 @@ POSITION: 'POSITION';
PRECEDING: 'PRECEDING';
PRIMARY: 'PRIMARY';
PRINCIPALS: 'PRINCIPALS';
+PROCEDURE: 'PROCEDURE';
+PROCEDURES: 'PROCEDURES';
PROPERTIES: 'PROPERTIES';
PURGE: 'PURGE';
QUARTER: 'QUARTER';
@@ -376,6 +381,7 @@ RECURSIVE: 'RECURSIVE';
REDUCE: 'REDUCE';
REFERENCES: 'REFERENCES';
REFRESH: 'REFRESH';
+RELY: 'RELY';
RENAME: 'RENAME';
REPAIR: 'REPAIR';
REPEAT: 'REPEAT';
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 59a0b1ce7a3c5..e15441e14e429 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -177,6 +177,10 @@ singleTableSchema
: colTypeList EOF
;
+singleRoutineParamList
+ : colDefinitionList EOF
+ ;
+
statement
: query #statementDefault
| executeImmediate #visitExecuteImmediate
@@ -198,7 +202,7 @@ statement
(RESTRICT | CASCADE)? #dropNamespace
| SHOW namespaces ((FROM | IN) multipartIdentifier)?
(LIKE? pattern=stringLit)? #showNamespaces
- | createTableHeader (LEFT_PAREN colDefinitionList RIGHT_PAREN)? tableProvider?
+ | createTableHeader (LEFT_PAREN tableElementList RIGHT_PAREN)? tableProvider?
createTableClauses
(AS? query)? #createTable
| CREATE TABLE (IF errorCapturingNot EXISTS)? target=tableIdentifier
@@ -208,7 +212,7 @@ statement
createFileFormat |
locationSpec |
(TBLPROPERTIES tableProps=propertyList))* #createTableLike
- | replaceTableHeader (LEFT_PAREN colDefinitionList RIGHT_PAREN)? tableProvider?
+ | replaceTableHeader (LEFT_PAREN tableElementList RIGHT_PAREN)? tableProvider?
createTableClauses
(AS? query)? #replaceTable
| ANALYZE TABLE identifierReference partitionSpec? COMPUTE STATISTICS
@@ -261,6 +265,10 @@ statement
| ALTER TABLE identifierReference
(clusterBySpec | CLUSTER BY NONE) #alterClusterBy
| ALTER TABLE identifierReference collationSpec #alterTableCollation
+ | ALTER TABLE identifierReference ADD tableConstraintDefinition #addTableConstraint
+ | ALTER TABLE identifierReference
+ DROP CONSTRAINT (IF EXISTS)? name=identifier
+ (RESTRICT | CASCADE)? #dropTableConstraint
| DROP TABLE (IF EXISTS)? identifierReference PURGE? #dropTable
| DROP VIEW (IF EXISTS)? identifierReference #dropView
| CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)?
@@ -304,10 +312,12 @@ statement
| SHOW PARTITIONS identifierReference partitionSpec? #showPartitions
| SHOW identifier? FUNCTIONS ((FROM | IN) ns=identifierReference)?
(LIKE? (legacy=multipartIdentifier | pattern=stringLit))? #showFunctions
+ | SHOW PROCEDURES ((FROM | IN) identifierReference)? #showProcedures
| SHOW CREATE TABLE identifierReference (AS SERDE)? #showCreateTable
| SHOW CURRENT namespace #showCurrentNamespace
| SHOW CATALOGS (LIKE? pattern=stringLit)? #showCatalogs
| (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction
+ | (DESC | DESCRIBE) PROCEDURE identifierReference #describeProcedure
| (DESC | DESCRIBE) namespace EXTENDED?
identifierReference #describeNamespace
| (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)?
@@ -1150,7 +1160,7 @@ datetimeUnit
;
primaryExpression
- : name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER | USER | SESSION_USER) #currentLike
+ : name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER | USER | SESSION_USER | CURRENT_TIME) #currentLike
| name=(TIMESTAMPADD | DATEADD | DATE_ADD) LEFT_PAREN (unit=datetimeUnit | invalidUnit=stringLit) COMMA unitsAmount=valueExpression COMMA timestamp=valueExpression RIGHT_PAREN #timestampadd
| name=(TIMESTAMPDIFF | DATEDIFF | DATE_DIFF | TIMEDIFF) LEFT_PAREN (unit=datetimeUnit | invalidUnit=stringLit) COMMA startTimestamp=valueExpression COMMA endTimestamp=valueExpression RIGHT_PAREN #timestampdiff
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
@@ -1190,6 +1200,7 @@ primaryExpression
literalType
: DATE
+ | TIME
| TIMESTAMP | TIMESTAMP_LTZ | TIMESTAMP_NTZ
| INTERVAL
| BINARY_HEX
@@ -1279,6 +1290,7 @@ type
| FLOAT | REAL
| DOUBLE
| DATE
+ | TIME
| TIMESTAMP | TIMESTAMP_NTZ | TIMESTAMP_LTZ
| STRING collateClause?
| CHARACTER | CHAR
@@ -1334,6 +1346,15 @@ colType
: colName=errorCapturingIdentifier dataType (errorCapturingNot NULL)? commentSpec?
;
+tableElementList
+ : tableElement (COMMA tableElement)*
+ ;
+
+tableElement
+ : tableConstraintDefinition
+ | colDefinition
+ ;
+
colDefinitionList
: colDefinition (COMMA colDefinition)*
;
@@ -1347,6 +1368,7 @@ colDefinitionOption
| defaultExpression
| generationExpression
| commentSpec
+ | columnConstraintDefinition
;
generationExpression
@@ -1516,6 +1538,62 @@ number
| MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral
;
+columnConstraintDefinition
+ : (CONSTRAINT name=errorCapturingIdentifier)? columnConstraint constraintCharacteristic*
+ ;
+
+columnConstraint
+ : checkConstraint
+ | uniqueSpec
+ | referenceSpec
+ ;
+
+tableConstraintDefinition
+ : (CONSTRAINT name=errorCapturingIdentifier)? tableConstraint constraintCharacteristic*
+ ;
+
+tableConstraint
+ : checkConstraint
+ | uniqueConstraint
+ | foreignKeyConstraint
+ ;
+
+checkConstraint
+ : CHECK LEFT_PAREN (expr=booleanExpression) RIGHT_PAREN
+ ;
+
+uniqueSpec
+ : UNIQUE
+ | PRIMARY KEY
+ ;
+
+uniqueConstraint
+ : uniqueSpec identifierList
+ ;
+
+referenceSpec
+ : REFERENCES multipartIdentifier (parentColumns=identifierList)?
+ ;
+
+foreignKeyConstraint
+ : FOREIGN KEY identifierList referenceSpec
+ ;
+
+constraintCharacteristic
+ : enforcedCharacteristic
+ | relyCharacteristic
+ ;
+
+enforcedCharacteristic
+ : ENFORCED
+ | NOT ENFORCED
+ ;
+
+relyCharacteristic
+ : RELY
+ | NORELY
+ ;
+
alterColumnSpecList
: alterColumnSpec (COMMA alterColumnSpec)*
;
@@ -1673,6 +1751,7 @@ ansiNonReserved
| DOUBLE
| DROP
| ELSEIF
+ | ENFORCED
| ESCAPED
| EVOLUTION
| EXCHANGE
@@ -1722,6 +1801,7 @@ ansiNonReserved
| ITEMS
| ITERATE
| JSON
+ | KEY
| KEYS
| LANGUAGE
| LAST
@@ -1761,6 +1841,7 @@ ansiNonReserved
| NANOSECONDS
| NO
| NONE
+ | NORELY
| NULLS
| NUMERIC
| OF
@@ -1780,6 +1861,8 @@ ansiNonReserved
| POSITION
| PRECEDING
| PRINCIPALS
+ | PROCEDURE
+ | PROCEDURES
| PROPERTIES
| PURGE
| QUARTER
@@ -1792,6 +1875,7 @@ ansiNonReserved
| RECOVER
| REDUCE
| REFRESH
+ | RELY
| RENAME
| REPAIR
| REPEAT
@@ -2028,6 +2112,7 @@ nonReserved
| ELSE
| ELSEIF
| END
+ | ENFORCED
| ESCAPE
| ESCAPED
| EVOLUTION
@@ -2091,6 +2176,7 @@ nonReserved
| ITEMS
| ITERATE
| JSON
+ | KEY
| KEYS
| LANGUAGE
| LAST
@@ -2132,6 +2218,7 @@ nonReserved
| NANOSECONDS
| NO
| NONE
+ | NORELY
| NOT
| NULL
| NULLS
@@ -2160,6 +2247,8 @@ nonReserved
| PRECEDING
| PRIMARY
| PRINCIPALS
+ | PROCEDURE
+ | PROCEDURES
| PROPERTIES
| PURGE
| QUARTER
@@ -2174,6 +2263,7 @@ nonReserved
| REDUCE
| REFERENCES
| REFRESH
+ | RELY
| RENAME
| REPAIR
| REPEAT
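The grammar changes above add table and column constraints (CHECK / UNIQUE / PRIMARY KEY / FOREIGN KEY with ENFORCED and RELY characteristics), SHOW PROCEDURES / DESCRIBE PROCEDURE, and a TIME type. A few statements of the shape the new rules are written to accept, assuming an active SparkSession named spark; whether each one runs end to end depends on analyzer and catalog support outside this parser change, and the table, constraint, and catalog names are made up.

    # Syntax illustrations only; execution depends on catalog/analyzer support.
    spark.sql("""
        CREATE TABLE orders (
            id BIGINT PRIMARY KEY RELY,
            amount DECIMAL(10, 2),
            CONSTRAINT positive_amount CHECK (amount > 0) NOT ENFORCED
        ) USING parquet
    """)
    spark.sql("ALTER TABLE orders ADD CONSTRAINT uniq_id UNIQUE (id) NORELY")
    spark.sql("ALTER TABLE orders DROP CONSTRAINT IF EXISTS uniq_id")
    spark.sql("SHOW PROCEDURES IN my_catalog.my_namespace")
    spark.sql("SELECT CAST('12:34:56' AS TIME)")  # new TIME type keyword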
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
index 7f5eed1eb1ade..88d597fdfbb73 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
@@ -803,6 +803,26 @@ class Column(val node: ColumnNode) extends Logging with TableValuedFunctionArgum
*/
def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala)
+ /**
+ * A boolean expression that is evaluated to true if the value of this expression is contained
+ * by the provided Dataset/DataFrame.
+ *
+ * @group subquery
+ * @since 4.1.0
+ */
+ def isin(ds: Dataset[_]): Column = {
+ if (ds == null) {
+ // A single null should be handled as a value.
+ isin(Seq(ds): _*)
+ } else {
+ val values = node match {
+ case internal.UnresolvedFunction("struct", arguments, _, _, _, _) => arguments
+ case _ => Seq(node)
+ }
+ Column(internal.SubqueryExpression(ds, internal.SubqueryType.IN(values)))
+ }
+ }
+
/**
* SQL like expression. Returns a boolean column based on a SQL LIKE match.
*
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
index 4952fa36f66ee..c287ad69de2e8 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1758,7 +1758,9 @@ abstract class Dataset[T] extends Serializable {
* @group subquery
* @since 4.0.0
*/
- def scalar(): Column
+ def scalar(): Column = {
+ Column(internal.SubqueryExpression(this, internal.SubqueryType.SCALAR))
+ }
/**
* Return a `Column` object for an EXISTS Subquery.
@@ -1771,7 +1773,9 @@ abstract class Dataset[T] extends Serializable {
* @group subquery
* @since 4.0.0
*/
- def exists(): Column
+ def exists(): Column = {
+ Column(internal.SubqueryExpression(this, internal.SubqueryType.EXISTS))
+ }
/**
* Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
index 4957d76af9a29..94a627fd17a64 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -147,6 +147,14 @@ object Encoders {
*/
def INSTANT: Encoder[java.time.Instant] = STRICT_INSTANT_ENCODER
+ /**
+ * Creates an encoder that serializes instances of the `java.time.LocalTime` class to the
+ * internal representation of nullable Catalyst's TimeType.
+ *
+ * @since 4.1.0
+ */
+ def LOCALTIME: Encoder[java.time.LocalTime] = LocalTimeEncoder
+
/**
* An encoder for arrays of bytes.
*
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 2f68d436acfcd..a5b1060ca03db 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -119,6 +119,9 @@ trait EncoderImplicits extends LowPrioritySQLImplicits with Serializable {
/** @since 3.0.0 */
implicit def newInstantEncoder: Encoder[java.time.Instant] = Encoders.INSTANT
+ /** @since 4.1.0 */
+ implicit def newLocalTimeEncoder: Encoder[java.time.LocalTime] = Encoders.LOCALTIME
+
/** @since 3.2.0 */
implicit def newDurationEncoder: Encoder[java.time.Duration] = Encoders.DURATION
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 3068b81c58c82..57b77d27b1265 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -141,6 +141,7 @@ abstract class Catalog {
* is an unqualified name that designates a table/view.
* @since 2.0.0
*/
+ @deprecated("use listColumns(tableName: String) instead.", "4.0.0")
@throws[AnalysisException]("database or table does not exist")
def listColumns(dbName: String, tableName: String): Dataset[Column]
@@ -175,6 +176,7 @@ abstract class Catalog {
*
* @since 2.1.0
*/
+ @deprecated("use getTable(tableName: String) instead.", "4.0.0")
@throws[AnalysisException]("database or table does not exist")
def getTable(dbName: String, tableName: String): Table
@@ -204,6 +206,7 @@ abstract class Catalog {
* is an unqualified name that designates a function in the specified database
* @since 2.1.0
*/
+ @deprecated("use getFunction(functionName: String) instead.", "4.0.0")
@throws[AnalysisException]("database or function does not exist")
def getFunction(dbName: String, functionName: String): Function
@@ -240,6 +243,7 @@ abstract class Catalog {
* is an unqualified name that designates a table.
* @since 2.1.0
*/
+ @deprecated("use tableExists(tableName: String) instead.", "4.0.0")
def tableExists(dbName: String, tableName: String): Boolean
/**
@@ -267,6 +271,7 @@ abstract class Catalog {
* is an unqualified name that designates a function.
* @since 2.1.0
*/
+ @deprecated("use functionExists(functionName: String) instead.", "4.0.0")
def functionExists(dbName: String, functionName: String): Boolean
/**
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 8d0103ca69635..327ec5aa344c1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
import org.apache.commons.lang3.reflect.{TypeUtils => JavaTypeUtils}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
@@ -89,6 +89,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.math.BigDecimal] => DEFAULT_JAVA_DECIMAL_ENCODER
case c: Class[_] if c == classOf[java.math.BigInteger] => JavaBigIntEncoder
case c: Class[_] if c == classOf[java.time.LocalDate] => STRICT_LOCAL_DATE_ENCODER
+ case c: Class[_] if c == classOf[java.time.LocalTime] => LocalTimeEncoder
case c: Class[_] if c == classOf[java.sql.Date] => STRICT_DATE_ENCODER
case c: Class[_] if c == classOf[java.time.Instant] => STRICT_INSTANT_ENCODER
case c: Class[_] if c == classOf[java.sql.Timestamp] => STRICT_TIMESTAMP_ENCODER
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index cd12cbd267cc4..d2e0053597e4f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -330,6 +330,7 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => STRICT_TIMESTAMP_ENCODER
case t if isSubtype(t, localTypeOf[java.time.Instant]) => STRICT_INSTANT_ENCODER
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
+ case t if isSubtype(t, localTypeOf[java.time.LocalTime]) => LocalTimeEncoder
case t if isSubtype(t, localTypeOf[VariantVal]) => VariantEncoder
case t if isSubtype(t, localTypeOf[Row]) => UnboundRowEncoder
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index d998502ac1b25..1dd939131ab96 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.encoders
import java.{sql => jsql}
import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt}
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period}
import scala.reflect.{classTag, ClassTag}
@@ -249,6 +249,7 @@ object AgnosticEncoders {
case class InstantEncoder(override val lenientSerialization: Boolean)
extends LeafEncoder[Instant](TimestampType)
case object LocalDateTimeEncoder extends LeafEncoder[LocalDateTime](TimestampNTZType)
+ case object LocalTimeEncoder extends LeafEncoder[LocalTime](TimeType())
case class SparkDecimalEncoder(dt: DecimalType) extends LeafEncoder[Decimal](dt)
case class ScalaDecimalEncoder(dt: DecimalType) extends LeafEncoder[BigDecimal](dt)
@@ -276,11 +277,14 @@ object AgnosticEncoders {
* another encoder. This is fallback for scenarios where objects can't be represented using
* standard encoders, an example of this is where we use a different (opaque) serialization
* format (i.e. java serialization, kryo serialization, or protobuf).
+ * @param nullable
+ *   defaults to false, indicating that the codec guarantees decode / encode results are non-nullable
*/
case class TransformingEncoder[I, O](
clsTag: ClassTag[I],
transformed: AgnosticEncoder[O],
- codecProvider: () => Codec[_ >: I, O])
+ codecProvider: () => Codec[_ >: I, O],
+ override val nullable: Boolean = false)
extends AgnosticEncoder[I] {
override def isPrimitive: Boolean = transformed.isPrimitive
override def dataType: DataType = transformed.dataType
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 7260ff8f9fefd..d5692bb85c4e9 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.reflect.classTag
import org.apache.spark.sql.{AnalysisException, Row}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.errors.{DataTypeErrorsBase, ExecutionErrors}
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types._
@@ -49,6 +49,7 @@ import org.apache.spark.util.ArrayImplicits._
* TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
*
* TimestampNTZType -> java.time.LocalDateTime
+ * TimeType -> java.time.LocalTime
*
* DayTimeIntervalType -> java.time.Duration
* YearMonthIntervalType -> java.time.Period
@@ -90,6 +91,7 @@ object RowEncoder extends DataTypeErrorsBase {
case TimestampNTZType => LocalDateTimeEncoder
case DateType if SqlApiConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient)
case DateType => DateEncoder(lenient)
+ case _: TimeType => LocalTimeEncoder
case CalendarIntervalType => CalendarIntervalEncoder
case _: DayTimeIntervalType => DayTimeIntervalEncoder
case _: YearMonthIntervalType => YearMonthIntervalEncoder
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
index 94e014fb77f1b..bf9a250d6499e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin}
import org.apache.spark.sql.connector.catalog.IdentityColumnSpec
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.internal.SqlApiConf
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, TimeType, VarcharType, VariantType, YearMonthIntervalType}
class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
protected def typedVisit[T](ctx: ParseTree): T = {
@@ -79,6 +79,8 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
case (FLOAT | REAL, Nil) => FloatType
case (DOUBLE, Nil) => DoubleType
case (DATE, Nil) => DateType
+ case (TIME, Nil) => TimeType(TimeType.MICROS_PRECISION)
+ case (TIME, precision :: Nil) => TimeType(precision.getText.toInt)
case (TIMESTAMP, Nil) => SqlApiConf.get.timestampType
case (TIMESTAMP_NTZ, Nil) => TimestampNTZType
case (TIMESTAMP_LTZ, Nil) => TimestampType
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
index 54af195847dac..28fccd2092b34 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
@@ -99,7 +99,6 @@ abstract class AbstractParser extends DataTypeParserInterface with Logging {
throw new ParseException(
command = Option(command),
start = e.origin,
- stop = e.origin,
errorClass = e.getCondition,
messageParameters = e.getMessageParameters.asScala.toMap,
queryContext = e.getQueryContext)
@@ -158,24 +157,19 @@ case object ParseErrorListener extends BaseErrorListener {
charPositionInLine: Int,
msg: String,
e: RecognitionException): Unit = {
- val (start, stop) = offendingSymbol match {
+ val start = offendingSymbol match {
case token: CommonToken =>
- val start = Origin(Some(line), Some(token.getCharPositionInLine))
- val length = token.getStopIndex - token.getStartIndex + 1
- val stop = Origin(Some(line), Some(token.getCharPositionInLine + length))
- (start, stop)
+ Origin(Some(line), Some(token.getCharPositionInLine))
case _ =>
- val start = Origin(Some(line), Some(charPositionInLine))
- (start, start)
+ Origin(Some(line), Some(charPositionInLine))
}
e match {
case sre: SparkRecognitionException if sre.errorClass.isDefined =>
- throw new ParseException(None, start, stop, sre.errorClass.get, sre.messageParameters)
+ throw new ParseException(None, start, sre.errorClass.get, sre.messageParameters)
case _ =>
throw new ParseException(
command = None,
start = start,
- stop = stop,
errorClass = "PARSE_SYNTAX_ERROR",
messageParameters = Map("error" -> msg, "hint" -> ""))
}
@@ -190,7 +184,6 @@ class ParseException private (
val command: Option[String],
message: String,
val start: Origin,
- val stop: Origin,
errorClass: Option[String] = None,
messageParameters: Map[String, String] = Map.empty,
queryContext: Array[QueryContext] = ParseException.getQueryContext())
@@ -208,24 +201,22 @@ class ParseException private (
Option(SparkParserUtils.command(ctx)),
SparkThrowableHelper.getMessage(errorClass, messageParameters),
SparkParserUtils.position(ctx.getStart),
- SparkParserUtils.position(ctx.getStop),
Some(errorClass),
messageParameters)
- def this(errorClass: String, ctx: ParserRuleContext) = this(errorClass, Map.empty, ctx)
+ def this(errorClass: String, ctx: ParserRuleContext) =
+ this(errorClass = errorClass, messageParameters = Map.empty, ctx = ctx)
/** Compose the message through SparkThrowableHelper given errorClass and messageParameters. */
def this(
command: Option[String],
start: Origin,
- stop: Origin,
errorClass: String,
messageParameters: Map[String, String]) =
this(
command,
SparkThrowableHelper.getMessage(errorClass, messageParameters),
start,
- stop,
Some(errorClass),
messageParameters,
queryContext = ParseException.getQueryContext())
@@ -233,7 +224,6 @@ class ParseException private (
def this(
command: Option[String],
start: Origin,
- stop: Origin,
errorClass: String,
messageParameters: Map[String, String],
queryContext: Array[QueryContext]) =
@@ -241,7 +231,6 @@ class ParseException private (
command,
SparkThrowableHelper.getMessage(errorClass, messageParameters),
start,
- stop,
Some(errorClass),
messageParameters,
queryContext)
@@ -282,7 +271,7 @@ class ParseException private (
} else {
(cl, messageParameters)
}
- new ParseException(Option(cmd), start, stop, newCl, params, queryContext)
+ new ParseException(Option(cmd), start, newCl, params, queryContext)
}
override def getQueryContext: Array[QueryContext] = queryContext
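A minimal sketch of raising the slimmed-down exception after this change; the command text and position below are made up for illustration:

```scala
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.trees.Origin

// Only the start position is carried now; the `stop` origin parameter is gone.
throw new ParseException(
  command = Some("SELECT 1 +"),          // hypothetical SQL text
  start = Origin(Some(1), Some(9)),      // line 1, char position 9
  errorClass = "PARSE_SYNTAX_ERROR",
  messageParameters = Map("error" -> "end of input", "hint" -> ""))
```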
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
index e870a83ec4ae6..da454c1c4214e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
@@ -21,7 +21,13 @@ import java.util.Locale
import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.streaming.TimeMode
-/** TimeMode types used in transformWithState operator */
+/**
+ * TimeMode types used in transformWithState operator
+ *
+ * Note that TimeMode.None() is represented here by the case object "NoTime", because an object
+ * named "None" would collide with Scala's built-in None. See SPARK-51151 for more info.
+ */
case object NoTime extends TimeMode
case object ProcessingTime extends TimeMode
@@ -31,7 +37,7 @@ case object EventTime extends TimeMode
object TimeModes {
def apply(timeMode: String): TimeMode = {
timeMode.toLowerCase(Locale.ROOT) match {
- case "none" =>
+ case "none" | "notime" =>
NoTime
case "processingtime" =>
ProcessingTime
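A quick illustration of the widened matching (illustrative sketch):

```scala
import org.apache.spark.sql.catalyst.plans.logical.{NoTime, TimeModes}

TimeModes("none")    // NoTime, as before
TimeModes("NoTime")  // also NoTime after this change; matching is case-insensitive
```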
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
index 71777906f868e..cf48cce931bc4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
@@ -264,19 +264,29 @@ private object DateTimeFormatterHelper {
toFormatter(builder, locale)
}
- lazy val fractionFormatter: DateTimeFormatter = {
- val builder = createBuilder()
- .append(DateTimeFormatter.ISO_LOCAL_DATE)
- .appendLiteral(' ')
+ private def appendTimeBuilder(builder: DateTimeFormatterBuilder) = {
+ builder
.appendValue(ChronoField.HOUR_OF_DAY, 2)
.appendLiteral(':')
.appendValue(ChronoField.MINUTE_OF_HOUR, 2)
.appendLiteral(':')
.appendValue(ChronoField.SECOND_OF_MINUTE, 2)
.appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true)
+ }
+
+ lazy val fractionFormatter: DateTimeFormatter = {
+ val dateBuilder = createBuilder()
+ .append(DateTimeFormatter.ISO_LOCAL_DATE)
+ .appendLiteral(' ')
+ val builder = appendTimeBuilder(dateBuilder)
toFormatter(builder, TimestampFormatter.defaultLocale)
}
+ lazy val fractionTimeFormatter: DateTimeFormatter = {
+ val builder = appendTimeBuilder(createBuilder())
+ toFormatter(builder, TimeFormatter.defaultLocale)
+ }
+
// SPARK-31892: The week-based date fields are rarely used and really confusing for parsing values
// to datetime, especially when they are mixed with other non-week-based ones;
// SPARK-31879: It's also difficult for us to restore the behavior of week-based date fields
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala
index 4d05f9079548c..b16ee9ad1929a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.util
import java.lang.invoke.{MethodHandles, MethodType}
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate, LocalDateTime, LocalTime, ZonedDateTime, ZoneId, ZoneOffset}
+import java.time.temporal.ChronoField.MICRO_OF_DAY
import java.util.TimeZone
import java.util.concurrent.TimeUnit.{MICROSECONDS, NANOSECONDS}
import java.util.regex.Pattern
@@ -29,7 +30,7 @@ import org.apache.spark.QueryContext
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, rebaseJulianToGregorianDays, rebaseJulianToGregorianMicros}
import org.apache.spark.sql.errors.ExecutionErrors
-import org.apache.spark.sql.types.{DateType, TimestampType}
+import org.apache.spark.sql.types.{DateType, TimestampType, TimeType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.SparkClassUtils
@@ -132,6 +133,39 @@ trait SparkDateTimeUtils {
}
}
+ /**
+ * Gets the number of microseconds since midnight using the session time zone.
+ */
+ def instantToMicrosOfDay(instant: Instant, timezone: String): Long = {
+ val zoneId = getZoneId(timezone)
+ val localDateTime = LocalDateTime.ofInstant(instant, zoneId)
+ localDateTime.toLocalTime.getLong(MICRO_OF_DAY)
+ }
+
+ /**
+ * Truncates a time value (in microseconds) to the specified fractional precision `p`.
+ *
+ * For example, if `p = 3`, we keep millisecond resolution and discard any digits beyond the
+ * millisecond place. So a time whose fractional part is `123456` microseconds (12:34:56.123456)
+ * becomes one with `123000` microseconds (12:34:56.123).
+ *
+ * @param micros
+ * The original time in microseconds.
+ * @param p
+ * The fractional second precision (range 0 to 6).
+ * @return
+ * The truncated microsecond value, preserving only `p` fractional digits.
+ */
+ def truncateTimeMicrosToPrecision(micros: Long, p: Int): Long = {
+ assert(
+ p >= TimeType.MIN_PRECISION && p <= TimeType.MICROS_PRECISION,
+ s"Fractional second precision $p out" +
+ s" of range [${TimeType.MIN_PRECISION}..${TimeType.MICROS_PRECISION}].")
+ val scale = TimeType.MICROS_PRECISION - p
+ val factor = math.pow(10, scale).toLong
+ (micros / factor) * factor
+ }
+
/**
* Converts the timestamp `micros` from one timezone to another.
*
@@ -184,6 +218,19 @@ trait SparkDateTimeUtils {
instantToMicros(instant)
}
+ /**
+ * Converts the local time to the number of microseconds within the day, from 0 to (24 * 60 * 60
+ * * 1000000) - 1.
+ */
+ def localTimeToMicros(localTime: LocalTime): Long = localTime.getLong(MICRO_OF_DAY)
+
+ /**
+ * Converts the number of microseconds within the day to the local time.
+ */
+ def microsToLocalTime(micros: Long): LocalTime = {
+ LocalTime.ofNanoOfDay(Math.multiplyExact(micros, NANOS_PER_MICROS))
+ }
+
/**
* Converts a local date at the default JVM time zone to the number of days since 1970-01-01 in
* the hybrid calendar (Julian + Gregorian) by discarding the time part. The resulted days are
@@ -646,6 +693,35 @@ trait SparkDateTimeUtils {
}
}
+ /**
+ * Trims and parses a given UTF8 string to a corresponding [[Long]] value representing the
+ * number of microseconds since midnight. The result is independent of time zones.
+ *
+ * The return type is [[Option]] in order to distinguish between 0L and null. Please refer to
+ * `parseTimestampString` for the allowed formats.
+ */
+ def stringToTime(s: UTF8String): Option[Long] = {
+ try {
+ val (segments, zoneIdOpt, justTime) = parseTimestampString(s)
+ // If the input string can't be parsed as a time, or it contains not only
+ // the time part or has time zone information, return None.
+ if (segments.isEmpty || !justTime || zoneIdOpt.isDefined) {
+ return None
+ }
+ val nanoseconds = MICROSECONDS.toNanos(segments(6))
+ val localTime = LocalTime.of(segments(3), segments(4), segments(5), nanoseconds.toInt)
+ Some(localTimeToMicros(localTime))
+ } catch {
+ case NonFatal(_) => None
+ }
+ }
+
+ def stringToTimeAnsi(s: UTF8String, context: QueryContext = null): Long = {
+ stringToTime(s).getOrElse {
+ throw ExecutionErrors.invalidInputInCastToDatetimeError(s, TimeType(), context)
+ }
+ }
+
/**
* Returns the index of the first non-whitespace and non-ISO control character in the byte
* array.
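A small sketch of the new time helpers, assuming the companion object exposes this trait's members as it does elsewhere in the diff:

```scala
import java.time.LocalTime

import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._
import org.apache.spark.unsafe.types.UTF8String

val micros = localTimeToMicros(LocalTime.parse("12:34:56.123456"))
truncateTimeMicrosToPrecision(micros, 3)          // drops digits beyond milliseconds: 12:34:56.123
microsToLocalTime(micros)                         // back to java.time.LocalTime
stringToTime(UTF8String.fromString("23:59:59"))   // Some(86399000000L)
```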
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
index 01ee899085701..9c9e623e03395 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
@@ -127,6 +127,14 @@ trait SparkParserUtils {
}
}
+ /** Get the code that creates the given node. */
+ def source(ctx: ParserRuleContext): String = {
+ // Note: `ctx.getText` returns a string without spaces, so we need to
+ // get the text from the underlying char stream instead.
+ val stream = ctx.getStart.getInputStream
+ stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
+ }
+
/** Convert a string token into a string. */
def string(token: Token): String = unescapeSQLString(token.getText)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimeFormatter.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimeFormatter.scala
new file mode 100644
index 0000000000000..46afbc8aca196
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimeFormatter.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import java.time.LocalTime
+import java.time.format.DateTimeFormatter
+import java.util.Locale
+
+import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._
+import org.apache.spark.unsafe.types.UTF8String
+
+sealed trait TimeFormatter extends Serializable {
+ def parse(s: String): Long // returns microseconds since midnight
+
+ def format(localTime: LocalTime): String
+ // Converts microseconds since midnight to a time string
+ def format(micros: Long): String
+
+ def validatePatternString(): Unit
+}
+
+/**
+ * The ISO time formatter is capable of formatting and parsing the ISO-8601 extended time format.
+ */
+class Iso8601TimeFormatter(pattern: String, locale: Locale, isParsing: Boolean)
+ extends TimeFormatter
+ with DateTimeFormatterHelper {
+
+ @transient
+ protected lazy val formatter: DateTimeFormatter =
+ getOrCreateFormatter(pattern, locale, isParsing)
+
+ override def parse(s: String): Long = {
+ val localTime = toLocalTime(formatter.parse(s))
+ localTimeToMicros(localTime)
+ }
+
+ override def format(localTime: LocalTime): String = {
+ localTime.format(formatter)
+ }
+
+ override def format(micros: Long): String = {
+ format(microsToLocalTime(micros))
+ }
+
+ override def validatePatternString(): Unit = {
+ try {
+ formatter
+ } catch checkInvalidPattern(pattern)
+ ()
+ }
+}
+
+/**
+ * The formatter parses/formats times according to the pattern `HH:mm:ss.[..fff..]` where
+ * `[..fff..]` is a fraction of second up to microsecond resolution. The formatter does not output
+ * trailing zeros in the fraction. For example, the time `15:00:01.123400` is formatted as the
+ * string `15:00:01.1234`.
+ */
+class FractionTimeFormatter
+ extends Iso8601TimeFormatter(
+ TimeFormatter.defaultPattern,
+ TimeFormatter.defaultLocale,
+ isParsing = false) {
+
+ @transient
+ override protected lazy val formatter: DateTimeFormatter =
+ DateTimeFormatterHelper.fractionTimeFormatter
+}
+
+/**
+ * A formatter for time values that doesn't require users to specify a pattern. When
+ * formatting, it uses the default pattern [[TimeFormatter.defaultPattern()]]. When parsing, it
+ * follows the CAST logic for converting strings to Catalyst's TimeType.
+ *
+ * @param locale
+ * The locale overrides the system locale and is used in formatting.
+ * @param isParsing
+ * Whether the formatter is used for parsing (`true`) or for formatting (`false`).
+ */
+class DefaultTimeFormatter(locale: Locale, isParsing: Boolean)
+ extends Iso8601TimeFormatter(TimeFormatter.defaultPattern, locale, isParsing) {
+
+ override def parse(s: String): Long = {
+ SparkDateTimeUtils.stringToTimeAnsi(UTF8String.fromString(s))
+ }
+}
+
+object TimeFormatter {
+
+ val defaultLocale: Locale = Locale.US
+
+ val defaultPattern: String = "HH:mm:ss"
+
+ private def getFormatter(
+ format: Option[String],
+ locale: Locale = defaultLocale,
+ isParsing: Boolean): TimeFormatter = {
+ val formatter = format
+ .map(new Iso8601TimeFormatter(_, locale, isParsing))
+ .getOrElse(new DefaultTimeFormatter(locale, isParsing))
+ formatter.validatePatternString()
+ formatter
+ }
+
+ def apply(format: String, locale: Locale, isParsing: Boolean): TimeFormatter = {
+ getFormatter(Some(format), locale, isParsing)
+ }
+
+ def apply(format: Option[String], isParsing: Boolean): TimeFormatter = {
+ getFormatter(format, defaultLocale, isParsing)
+ }
+
+ def apply(format: String, isParsing: Boolean): TimeFormatter = apply(Some(format), isParsing)
+
+ def apply(format: String): TimeFormatter = {
+ getFormatter(Some(format), defaultLocale, isParsing = false)
+ }
+
+ def apply(isParsing: Boolean): TimeFormatter = {
+ getFormatter(None, defaultLocale, isParsing)
+ }
+}
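A usage sketch of the new formatters, assuming the parsing behavior documented above:

```scala
import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, TimeFormatter}

val parser = TimeFormatter(isParsing = true)   // DefaultTimeFormatter with CAST-style parsing
val micros = parser.parse("15:00:01.1234")     // microseconds since midnight
new FractionTimeFormatter().format(micros)     // "15:00:01.1234", trailing zeros are not emitted
```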
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
index 4fcb84daf187d..15784e9762e35 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
@@ -545,7 +545,8 @@ object TimestampFormatter {
val defaultLocale: Locale = Locale.US
- def defaultPattern(): String = s"${DateFormatter.defaultPattern} HH:mm:ss"
+ def defaultPattern(): String =
+ s"${DateFormatter.defaultPattern} ${TimeFormatter.defaultPattern}"
private def getFormatter(
format: Option[String],
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
index c69c5bfb52616..1e2b2e691cd31 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
@@ -207,7 +207,8 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
messageParameters = Map(
"expression" -> convertedValueStr,
"sourceType" -> toSQLType(StringType),
- "targetType" -> toSQLType(to)),
+ "targetType" -> toSQLType(to),
+ "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")),
context = getQueryContext(context),
summary = getSummary(context))
}
@@ -225,8 +226,11 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
def castingCauseOverflowError(t: String, from: DataType, to: DataType): ArithmeticException = {
new SparkArithmeticException(
errorClass = "CAST_OVERFLOW",
- messageParameters =
- Map("value" -> t, "sourceType" -> toSQLType(from), "targetType" -> toSQLType(to)),
+ messageParameters = Map(
+ "value" -> t,
+ "sourceType" -> toSQLType(from),
+ "targetType" -> toSQLType(to),
+ "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")),
context = Array.empty,
summary = "")
}
@@ -264,4 +268,11 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
errorClass = "_LEGACY_ERROR_TEMP_1189",
messageParameters = Map("operation" -> operation))
}
+
+ def unsupportedTimePrecisionError(precision: Int): Throwable = {
+ new SparkException(
+ errorClass = "UNSUPPORTED_TIME_PRECISION",
+ messageParameters = Map("precision" -> precision.toString),
+ cause = null)
+ }
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index fca3ea8fdb908..8124b1a4ab197 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -73,7 +73,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
def failToRecognizePatternError(pattern: String, e: Throwable): SparkRuntimeException = {
new SparkRuntimeException(
- errorClass = "_LEGACY_ERROR_TEMP_2130",
+ errorClass = "INVALID_DATETIME_PATTERN.WITH_SUGGESTION",
messageParameters =
Map("pattern" -> toSQLValue(pattern), "docroot" -> SparkBuildInfo.spark_doc_root),
cause = e)
@@ -109,7 +109,8 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
messageParameters = Map(
"expression" -> sqlValue,
"sourceType" -> toSQLType(from),
- "targetType" -> toSQLType(to)),
+ "targetType" -> toSQLType(to),
+ "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")),
context = getQueryContext(context),
summary = getSummary(context))
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index 0bd9f38014984..3f6c20679e219 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -500,7 +500,6 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
new ParseException(
command = Option(sqlText),
start = position,
- stop = position,
errorClass = "INVALID_SQL_SYNTAX.UNSUPPORTED_SQL_STATEMENT",
messageParameters = Map("sqlText" -> sqlText))
}
@@ -632,7 +631,6 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
new ParseException(
command = origin.sqlText,
start = origin,
- stop = origin,
errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
messageParameters = Map("statement" -> statement))
}
@@ -656,6 +654,16 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
ctx)
}
+ def createFuncWithGeneratedColumnsError(ctx: ParserRuleContext): Throwable = {
+ new ParseException(
+ errorClass = "INVALID_SQL_SYNTAX.CREATE_FUNC_WITH_GENERATED_COLUMNS_AS_PARAMETERS",
+ ctx)
+ }
+
+ def createFuncWithConstraintError(ctx: ParserRuleContext): Throwable = {
+ new ParseException(errorClass = "INVALID_SQL_SYNTAX.CREATE_FUNC_WITH_COLUMN_CONSTRAINTS", ctx)
+ }
+
def defineTempFuncWithIfNotExistsError(ctx: ParserRuleContext): Throwable = {
new ParseException(errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS", ctx)
}
@@ -690,7 +698,6 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
new ParseException(
command = Some(command),
start = start,
- stop = stop,
errorClass = "UNCLOSED_BRACKETED_COMMENT",
messageParameters = Map.empty)
}
@@ -791,4 +798,20 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
def clusterByWithBucketing(ctx: ParserRuleContext): Throwable = {
new ParseException(errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", ctx)
}
+
+ def invalidConstraintCharacteristics(
+ ctx: ParserRuleContext,
+ characteristics: String): Throwable = {
+ new ParseException(
+ errorClass = "INVALID_CONSTRAINT_CHARACTERISTICS",
+ messageParameters = Map("characteristics" -> characteristics),
+ ctx)
+ }
+
+ def multiplePrimaryKeysError(ctx: ParserRuleContext, columns: String): Throwable = {
+ new ParseException(
+ errorClass = "MULTIPLE_PRIMARY_KEYS",
+ messageParameters = Map("columns" -> columns),
+ ctx)
+ }
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index 81f1e3153df33..ce5c76807b5c1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -6910,7 +6910,7 @@ object functions {
*/
// scalastyle:on line.size.limit
def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = {
- from_json(e, lit(schema.json), options.iterator)
+ from_json(e, lit(schema.sql), options.iterator)
}
// scalastyle:off line.size.limit
@@ -7782,7 +7782,7 @@ object functions {
*/
// scalastyle:on line.size.limit
def from_xml(e: Column, schema: StructType, options: java.util.Map[String, String]): Column =
- from_xml(e, lit(schema.json), options.asScala.iterator)
+ from_xml(e, lit(schema.sql), options.asScala.iterator)
// scalastyle:off line.size.limit
/**
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
index 463307409839d..4a7339165cb57 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
@@ -20,6 +20,7 @@ import java.util.concurrent.atomic.AtomicLong
import ColumnNode._
+import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.catalyst.util.AttributeNameParser
import org.apache.spark.sql.errors.DataTypeErrorsBase
@@ -651,3 +652,24 @@ private[sql] case class LazyExpression(
override def sql: String = "lazy" + argumentsToSql(Seq(child))
override private[internal] def children: Seq[ColumnNodeLike] = Seq(child)
}
+
+sealed trait SubqueryType
+
+object SubqueryType {
+ case object SCALAR extends SubqueryType
+ case object EXISTS extends SubqueryType
+ case class IN(values: Seq[ColumnNode]) extends SubqueryType
+}
+
+case class SubqueryExpression(
+ ds: Dataset[_],
+ subqueryType: SubqueryType,
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode {
+ override def sql: String = subqueryType match {
+ case SubqueryType.SCALAR => s"($ds)"
+ case SubqueryType.IN(values) => s"(${values.map(_.sql).mkString(",")}) IN ($ds)"
+ case _ => s"$subqueryType ($ds)"
+ }
+ override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index db7e7c0ae1885..f798276d60f7c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -174,7 +174,7 @@ object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
private val otherTypes = {
- Seq(
+ (Seq(
NullType,
DateType,
TimestampType,
@@ -202,7 +202,8 @@ object DataType {
YearMonthIntervalType(MONTH),
YearMonthIntervalType(YEAR, MONTH),
TimestampNTZType,
- VariantType)
+ VariantType) ++
+ (TimeType.MIN_PRECISION to TimeType.MAX_PRECISION).map(TimeType(_)))
.map(t => t.typeName -> t)
.toMap
}
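A quick check of the new lookup, assuming the usual JSON form of primitive type names:

```scala
import org.apache.spark.sql.types.{DataType, TimeType}

DataType.fromJson("\"time(6)\"")  // TimeType(6)
DataType.fromJson("\"time(0)\"")  // TimeType(0)
```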
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
index a0e08745d8af8..4c51980d4e6c4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
@@ -150,10 +150,13 @@ case class StructField(
/**
* Return the default value of this StructField. This is used for storing the default value of a
* function parameter.
+ *
+ * It is present when the field represents a function parameter with a default value, such as
+ * `CREATE FUNCTION f(arg INT DEFAULT 42) RETURN ...`.
*/
private[sql] def getDefault(): Option[String] = {
- if (metadata.contains("default")) {
- Option(metadata.getString("default"))
+ if (metadata.contains(StructType.SQL_FUNCTION_DEFAULT_METADATA_KEY)) {
+ Option(metadata.getString(StructType.SQL_FUNCTION_DEFAULT_METADATA_KEY))
} else {
None
}
@@ -183,6 +186,9 @@ case class StructField(
/**
* Return the current default value of this StructField.
+ *
+ * It is present only when the field represents a table column with a default value, such as:
+ * `ALTER TABLE t ALTER COLUMN c SET DEFAULT 42`.
*/
def getCurrentDefaultValue(): Option[String] = {
if (metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) {
@@ -214,7 +220,12 @@ case class StructField(
}
}
- private def getDDLDefault = getCurrentDefaultValue()
+ private[sql] def hasExistenceDefaultValue: Boolean = {
+ metadata.contains(EXISTS_DEFAULT_COLUMN_METADATA_KEY)
+ }
+
+ private def getDDLDefault = getDefault()
+ .orElse(getCurrentDefaultValue())
.map(" DEFAULT " + _)
.getOrElse("")
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
index cc95d8ee94b02..4c49d3a58f4fc 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -521,6 +521,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
*/
@Stable
object StructType extends AbstractDataType {
+ private[sql] val SQL_FUNCTION_DEFAULT_METADATA_KEY = "default"
override private[sql] def defaultConcreteType: DataType = new StructType
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/TimeType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/TimeType.scala
new file mode 100644
index 0000000000000..c42311c6a1dcc
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/TimeType.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.annotation.Unstable
+import org.apache.spark.sql.errors.DataTypeErrors
+
+/**
+ * The time type represents a time value with fields hour, minute, second, up to microseconds. The
+ * range of times supported is 00:00:00.000000 to 23:59:59.999999.
+ *
+ * @param precision
+ * The time fractional seconds precision which indicates the number of decimal digits maintained
+ * following the decimal point in the seconds value. The supported range is [0, 6].
+ *
+ * @since 4.1.0
+ */
+@Unstable
+case class TimeType(precision: Int) extends DatetimeType {
+
+ if (precision < TimeType.MIN_PRECISION || precision > TimeType.MAX_PRECISION) {
+ throw DataTypeErrors.unsupportedTimePrecisionError(precision)
+ }
+
+ /**
+ * The default size of a value of the TimeType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
+
+ override def typeName: String = s"time($precision)"
+
+ private[spark] override def asNullable: TimeType = this
+}
+
+object TimeType {
+ val MIN_PRECISION: Int = 0
+ val MICROS_PRECISION: Int = 6
+ val MAX_PRECISION: Int = MICROS_PRECISION
+
+ def apply(): TimeType = new TimeType(MICROS_PRECISION)
+}
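Construction examples for the new type (illustrative only):

```scala
import org.apache.spark.sql.types.TimeType

TimeType()      // time(6), the default microsecond precision
TimeType(0)     // time(0), whole seconds
// TimeType(9)  // rejected with UNSUPPORTED_TIME_PRECISION; valid precisions are 0 to 6
```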
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/HadoopCompressionCodec.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/HadoopCompressionCodec.java
index ee4cb4da322b8..df08413d8d3bb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/HadoopCompressionCodec.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/HadoopCompressionCodec.java
@@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.util;
import java.util.Arrays;
+import java.util.EnumMap;
import java.util.Locale;
-import java.util.Map;
import java.util.stream.Collectors;
import org.apache.hadoop.io.compress.BZip2Codec;
@@ -53,11 +53,15 @@ public CompressionCodec getCompressionCodec() {
return this.compressionCodec;
}
- private static final Map<String, String> codecNameMap =
+ private static final EnumMap<HadoopCompressionCodec, String> codecNameMap =
Arrays.stream(HadoopCompressionCodec.values()).collect(
- Collectors.toMap(Enum::name, codec -> codec.name().toLowerCase(Locale.ROOT)));
+ Collectors.toMap(
+ codec -> codec,
+ codec -> codec.name().toLowerCase(Locale.ROOT),
+ (oldValue, newValue) -> oldValue,
+ () -> new EnumMap<>(HadoopCompressionCodec.class)));
public String lowerCaseName() {
- return codecNameMap.get(this.name());
+ return codecNameMap.get(this);
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DefaultValue.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DefaultValue.java
new file mode 100644
index 0000000000000..6e487ce326a4a
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DefaultValue.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog;
+
+import java.util.Map;
+import java.util.Objects;
+import javax.annotation.Nullable;
+
+import org.apache.spark.SparkIllegalArgumentException;
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
+
+/**
+ * A class that represents default values.
+ *
+ * Connectors can define default values using either a SQL string (Spark SQL dialect) or an
+ * {@link Expression expression} if the default value can be expressed as a supported connector
+ * expression. If both the SQL string and the expression are provided, Spark first attempts to
+ * convert the given expression to its internal representation. If the expression cannot be
+ * converted, and a SQL string is provided, Spark will fall back to parsing the SQL string.
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public class DefaultValue {
+ private final String sql;
+ private final Expression expr;
+
+ public DefaultValue(String sql) {
+ this(sql, null /* no expression */);
+ }
+
+ public DefaultValue(Expression expr) {
+ this(null /* no sql */, expr);
+ }
+
+ public DefaultValue(String sql, Expression expr) {
+ if (sql == null && expr == null) {
+ throw new SparkIllegalArgumentException(
+ "INTERNAL_ERROR",
+ Map.of("message", "SQL and expression can't be both null"));
+ }
+ this.sql = sql;
+ this.expr = expr;
+ }
+
+ /**
+ * Returns the SQL representation of the default value (Spark SQL dialect), if provided.
+ */
+ @Nullable
+ public String getSql() {
+ return sql;
+ }
+
+ /**
+ * Returns the expression representing the default value, if provided.
+ */
+ @Nullable
+ public Expression getExpression() {
+ return expr;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) return true;
+ if (other == null || getClass() != other.getClass()) return false;
+ DefaultValue that = (DefaultValue) other;
+ return Objects.equals(sql, that.sql) && Objects.equals(expr, that.expr);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(sql, expr);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("DefaultValue{sql=%s, expression=%s}", sql, expr);
+ }
+}
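A sketch of the three ways a connector can describe a default; `LiteralValue` is the existing connector literal expression:

```scala
import org.apache.spark.sql.connector.catalog.DefaultValue
import org.apache.spark.sql.connector.expressions.LiteralValue
import org.apache.spark.sql.types.IntegerType

new DefaultValue("41 + 1")                             // SQL string only
new DefaultValue(LiteralValue(42, IntegerType))        // expression only
new DefaultValue("42", LiteralValue(42, IntegerType))  // both; Spark tries the expression first
```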
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ProcedureCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ProcedureCatalog.java
index 6eaacf340cb80..ef799a6c30dfb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ProcedureCatalog.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ProcedureCatalog.java
@@ -34,4 +34,10 @@ public interface ProcedureCatalog extends CatalogPlugin {
* @return the loaded unbound procedure
*/
UnboundProcedure loadProcedure(Identifier ident);
+
+ /**
+ * List all procedures in the specified namespace.
+ * @param namespace a multi-part namespace
+ * @return identifiers of the procedures in the given namespace
+ */
+ Identifier[] listProcedures(String[] namespace);
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java
index f457a4a3d7863..6811ea380b3ae 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java
@@ -34,14 +34,14 @@
/**
* An optional mix-in for implementations of {@link TableCatalog} that support staging creation of
- * the a table before committing the table's metadata along with its contents in CREATE TABLE AS
+ * a table before committing the table's metadata along with its contents in CREATE TABLE AS
* SELECT or REPLACE TABLE AS SELECT operations.
*
* It is highly recommended to implement this trait whenever possible so that CREATE TABLE AS
* SELECT and REPLACE TABLE AS SELECT operations are atomic. For example, when one runs a REPLACE
* TABLE AS SELECT operation, if the catalog does not implement this trait, the planner will first
* drop the table via {@link TableCatalog#dropTable(Identifier)}, then create the table via
- * {@link TableCatalog#createTable(Identifier, Column[], Transform[], Map)}, and then perform
+ * {@link TableCatalog#createTable(Identifier, TableInfo)}, and then perform
* the write via {@link SupportsWrite#newWriteBuilder(LogicalWriteInfo)}.
* However, if the write operation fails, the catalog will have already dropped the table, and the
* planner cannot roll back the dropping of the table.
@@ -72,6 +72,21 @@ default StagedTable stageCreate(
throw QueryCompilationErrors.mustOverrideOneMethodError("stageCreate");
}
+ /**
+ * Stage the creation of a table, preparing it to be committed into the metastore.
+ *
+ * @deprecated This is deprecated. Please override
+ * {@link #stageCreate(Identifier, TableInfo)} instead.
+ */
+ @Deprecated(since = "4.1.0")
+ default StagedTable stageCreate(
+ Identifier ident,
+ Column[] columns,
+ Transform[] partitions,
+ Map<String, String> properties) throws TableAlreadyExistsException, NoSuchNamespaceException {
+ return stageCreate(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
+ }
+
/**
* Stage the creation of a table, preparing it to be committed into the metastore.
*
@@ -82,21 +97,16 @@ default StagedTable stageCreate(
* committed, an exception should be thrown by {@link StagedTable#commitStagedChanges()}.
*
* @param ident a table identifier
- * @param columns the column of the new table
- * @param partitions transforms to use for partitioning data in the table
- * @param properties a string map of table properties
+ * @param tableInfo information about the table
* @return metadata for the new table. This can be null if the catalog does not support atomic
* creation for this table. Spark will call {@link #loadTable(Identifier)} later.
* @throws TableAlreadyExistsException If a table or view already exists for the identifier
* @throws UnsupportedOperationException If a requested partition transform is not supported
* @throws NoSuchNamespaceException If the identifier namespace does not exist (optional)
*/
- default StagedTable stageCreate(
- Identifier ident,
- Column[] columns,
- Transform[] partitions,
- Map<String, String> properties) throws TableAlreadyExistsException, NoSuchNamespaceException {
- return stageCreate(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
+ default StagedTable stageCreate(Identifier ident, TableInfo tableInfo)
+ throws TableAlreadyExistsException, NoSuchNamespaceException {
+ return stageCreate(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties());
}
/**
@@ -115,6 +125,23 @@ default StagedTable stageReplace(
throw QueryCompilationErrors.mustOverrideOneMethodError("stageReplace");
}
+ /**
+ * Stage the replacement of a table, preparing it to be committed into the metastore when the
+ * returned table's {@link StagedTable#commitStagedChanges()} is called.
+ *
+ * @deprecated This is deprecated. Please override
+ * {@link #stageReplace(Identifier, TableInfo)} instead.
+ */
+ @Deprecated(since = "4.1.0")
+ default StagedTable stageReplace(
+ Identifier ident,
+ Column[] columns,
+ Transform[] partitions,
+ Map<String, String> properties) throws NoSuchNamespaceException, NoSuchTableException {
+ return stageReplace(
+ ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
+ }
+
/**
* Stage the replacement of a table, preparing it to be committed into the metastore when the
* returned table's {@link StagedTable#commitStagedChanges()} is called.
@@ -134,22 +161,16 @@ default StagedTable stageReplace(
* operation.
*
* @param ident a table identifier
- * @param columns the columns of the new table
- * @param partitions transforms to use for partitioning data in the table
- * @param properties a string map of table properties
+ * @param tableInfo information about the table
* @return metadata for the new table. This can be null if the catalog does not support atomic
* creation for this table. Spark will call {@link #loadTable(Identifier)} later.
* @throws UnsupportedOperationException If a requested partition transform is not supported
* @throws NoSuchNamespaceException If the identifier namespace does not exist (optional)
* @throws NoSuchTableException If the table does not exist
*/
- default StagedTable stageReplace(
- Identifier ident,
- Column[] columns,
- Transform[] partitions,
- Map<String, String> properties) throws NoSuchNamespaceException, NoSuchTableException {
- return stageReplace(
- ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
+ default StagedTable stageReplace(Identifier ident, TableInfo tableInfo)
+ throws NoSuchNamespaceException, NoSuchTableException {
+ return stageReplace(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties());
}
/**
@@ -168,6 +189,23 @@ default StagedTable stageCreateOrReplace(
throw QueryCompilationErrors.mustOverrideOneMethodError("stageCreateOrReplace");
}
+ /**
+ * Stage the creation or replacement of a table, preparing it to be committed into the metastore
+ * when the returned table's {@link StagedTable#commitStagedChanges()} is called.
+ *
+ * @deprecated This is deprecated. Please override
+ * {@link #stageCreateOrReplace(Identifier, TableInfo)} instead.
+ */
+ @Deprecated(since = "4.1.0")
+ default StagedTable stageCreateOrReplace(
+ Identifier ident,
+ Column[] columns,
+ Transform[] partitions,
+ Map<String, String> properties) throws NoSuchNamespaceException {
+ return stageCreateOrReplace(
+ ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
+ }
+
/**
* Stage the creation or replacement of a table, preparing it to be committed into the metastore
* when the returned table's {@link StagedTable#commitStagedChanges()} is called.
@@ -186,21 +224,18 @@ default StagedTable stageCreateOrReplace(
* the staged changes are committed but the table doesn't exist at commit time.
*
* @param ident a table identifier
- * @param columns the columns of the new table
- * @param partitions transforms to use for partitioning data in the table
- * @param properties a string map of table properties
+ * @param tableInfo information about the table
* @return metadata for the new table. This can be null if the catalog does not support atomic
* creation for this table. Spark will call {@link #loadTable(Identifier)} later.
* @throws UnsupportedOperationException If a requested partition transform is not supported
* @throws NoSuchNamespaceException If the identifier namespace does not exist (optional)
*/
- default StagedTable stageCreateOrReplace(
- Identifier ident,
- Column[] columns,
- Transform[] partitions,
- Map<String, String> properties) throws NoSuchNamespaceException {
- return stageCreateOrReplace(
- ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
+ default StagedTable stageCreateOrReplace(Identifier ident, TableInfo tableInfo)
+ throws NoSuchNamespaceException {
+ return stageCreateOrReplace(ident,
+ tableInfo.columns(),
+ tableInfo.partitions(),
+ tableInfo.properties());
}
/**
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java
index d5eb03dcf94d4..f9a75ccd1c8da 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java
@@ -18,6 +18,7 @@
package org.apache.spark.sql.connector.catalog;
import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.catalog.constraints.Constraint;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.types.StructType;
@@ -83,4 +84,15 @@ default Map<String, String> properties() {
* Returns the set of capabilities for this table.
*/
Set<TableCapability> capabilities();
+
+ /**
+ * Returns the constraints for this table.
+ */
+ default Constraint[] constraints() { return new Constraint[0]; }
+
+ /**
+ * Returns the current table version if the implementation supports versioning.
+ * If the table is not versioned, null can be returned.
+ */
+ default String currentVersion() { return null; }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
index 77dbaa7687b41..f2cbafbe8e5bd 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
@@ -33,11 +33,11 @@
/**
* Catalog methods for working with Tables.
*
- * TableCatalog implementations may be case sensitive or case insensitive. Spark will pass
+ * TableCatalog implementations may be case-sensitive or case-insensitive. Spark will pass
* {@link Identifier table identifiers} without modification. Field names passed to
* {@link #alterTable(Identifier, TableChange...)} will be normalized to match the case used in the
- * table schema when updating, renaming, or dropping existing columns when catalyst analysis is case
- * insensitive.
+ * table schema when updating, renaming, or dropping existing columns when catalyst analysis is
+ * case-insensitive.
*
* @since 3.0.0
*/
@@ -208,26 +208,37 @@ default Table createTable(
throw QueryCompilationErrors.mustOverrideOneMethodError("createTable");
}
+ /**
+ * Create a table in the catalog.
+ *
+ * @deprecated This is deprecated. Please override
+ * {@link #createTable(Identifier, TableInfo)} instead.
+ */
+ @Deprecated(since = "4.1.0")
+ default Table createTable(
+ Identifier ident,
+ Column[] columns,
+ Transform[] partitions,
+ Map<String, String> properties) throws TableAlreadyExistsException, NoSuchNamespaceException {
+ return createTable(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
+ }
+
/**
* Create a table in the catalog.
*
* @param ident a table identifier
- * @param columns the columns of the new table.
- * @param partitions transforms to use for partitioning data in the table
- * @param properties a string map of table properties
+ * @param tableInfo information about the table.
* @return metadata for the new table. This can be null if getting the metadata for the new table
* is expensive. Spark will call {@link #loadTable(Identifier)} if needed (e.g. CTAS).
*
* @throws TableAlreadyExistsException If a table or view already exists for the identifier
* @throws UnsupportedOperationException If a requested partition transform is not supported
* @throws NoSuchNamespaceException If the identifier namespace does not exist (optional)
+ * @since 4.1.0
*/
- default Table createTable(
- Identifier ident,
- Column[] columns,
- Transform[] partitions,
- Map<String, String> properties) throws TableAlreadyExistsException, NoSuchNamespaceException {
- return createTable(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
+ default Table createTable(Identifier ident, TableInfo tableInfo)
+ throws TableAlreadyExistsException, NoSuchNamespaceException {
+ return createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties());
}
/**
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java
index dceac1b484cf2..a60c827d5ace1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java
@@ -61,6 +61,21 @@ public enum TableCatalogCapability {
*/
SUPPORT_COLUMN_DEFAULT_VALUE,
+ /**
+ * Signals that the TableCatalog supports defining table constraints in
+ * CREATE/REPLACE/ALTER TABLE.
+ *
+ * Without this capability, any CREATE/REPLACE/ALTER TABLE statement with table constraints
+ * defined in the table schema will throw an exception during analysis.
+ *
+ * Table constraints include CHECK, PRIMARY KEY, UNIQUE and FOREIGN KEY constraints.
+ *
+ * Table constraints are included in the table schema for APIs like
+ * {@link TableCatalog#createTable}.
+ * See {@link Table#constraints()}.
+ */
+ SUPPORT_TABLE_CONSTRAINT,
+
/**
* Signals that the TableCatalog supports defining identity columns upon table creation in SQL.
*
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java
index d7a51c519e09b..a53962f8f3008 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java
@@ -22,6 +22,7 @@
import javax.annotation.Nullable;
import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.catalog.constraints.Constraint;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.types.DataType;
@@ -297,6 +298,21 @@ public int hashCode() {
}
}
+ /**
+ * Create a TableChange for adding a new table constraint
+ */
+ static TableChange addConstraint(Constraint constraint, String validatedTableVersion) {
+ return new AddConstraint(constraint, validatedTableVersion);
+ }
+
+ /**
+ * Create a TableChange for dropping a table constraint
+ */
+ static TableChange dropConstraint(String name, boolean ifExists, boolean cascade) {
+ DropConstraint.Mode mode = cascade ? DropConstraint.Mode.CASCADE : DropConstraint.Mode.RESTRICT;
+ return new DropConstraint(name, ifExists, mode);
+ }
+
/**
* A TableChange to remove a table property.
*
@@ -787,4 +803,82 @@ public int hashCode() {
return Arrays.hashCode(clusteringColumns);
}
}
+
+ /** A TableChange to alter table and add a constraint. */
+ final class AddConstraint implements TableChange {
+ private final Constraint constraint;
+ private final String validatedTableVersion;
+
+ private AddConstraint(Constraint constraint, String validatedTableVersion) {
+ this.constraint = constraint;
+ this.validatedTableVersion = validatedTableVersion;
+ }
+
+ public Constraint constraint() {
+ return constraint;
+ }
+
+ public String validatedTableVersion() {
+ return validatedTableVersion;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ AddConstraint that = (AddConstraint) o;
+ return constraint.equals(that.constraint) &&
+ Objects.equals(validatedTableVersion, that.validatedTableVersion);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(constraint, validatedTableVersion);
+ }
+ }
+
+ /** A TableChange to alter table and drop a constraint. */
+ final class DropConstraint implements TableChange {
+ private final String name;
+ private final boolean ifExists;
+ private final Mode mode;
+
+ /**
+ * Defines modes for dropping a constraint.
+ *
+ * RESTRICT - Prevents dropping a constraint if it is referenced by other objects.
+ * CASCADE - Automatically drops objects that depend on the constraint.
+ */
+ public enum Mode { RESTRICT, CASCADE }
+
+ private DropConstraint(String name, boolean ifExists, Mode mode) {
+ this.name = name;
+ this.ifExists = ifExists;
+ this.mode = mode;
+ }
+
+ public String name() {
+ return name;
+ }
+
+ public boolean ifExists() {
+ return ifExists;
+ }
+
+ public Mode mode() {
+ return mode;
+ }
+
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ DropConstraint that = (DropConstraint) o;
+ return that.name.equals(name) && that.ifExists == ifExists && mode == that.mode;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(name, ifExists, mode);
+ }
+ }
}
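A sketch of the new constraint-related changes; the constraint name is hypothetical, and building a concrete `Constraint` (e.g. a CHECK) is left out here:

```scala
import org.apache.spark.sql.connector.catalog.TableChange

// cascade = false maps to DropConstraint.Mode.RESTRICT
val drop = TableChange.dropConstraint("positive_id", true, false)
// val add = TableChange.addConstraint(someCheckConstraint, "version-at-validation-time")
```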
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableInfo.java
new file mode 100644
index 0000000000000..8dc71c5aee472
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableInfo.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connector.catalog;
+
+import static com.google.common.base.Preconditions.checkNotNull;
+import com.google.common.collect.Maps;
+import org.apache.spark.sql.connector.catalog.constraints.Constraint;
+import org.apache.spark.sql.connector.expressions.Transform;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.Map;
+
+public class TableInfo {
+
+ private final Column[] columns;
+ private final Map<String, String> properties;
+ private final Transform[] partitions;
+ private final Constraint[] constraints;
+
+ /**
+ * Constructor for TableInfo used by the builder.
+ * @param builder Builder.
+ */
+ private TableInfo(Builder builder) {
+ this.columns = builder.columns;
+ this.properties = builder.properties;
+ this.partitions = builder.partitions;
+ this.constraints = builder.constraints;
+ }
+
+ public Column[] columns() {
+ return columns;
+ }
+
+ public StructType schema() {
+ return CatalogV2Util.v2ColumnsToStructType(columns);
+ }
+
+ public Map<String, String> properties() {
+ return properties;
+ }
+
+ public Transform[] partitions() {
+ return partitions;
+ }
+
+ public Constraint[] constraints() { return constraints; }
+
+ public static class Builder {
+ private Column[] columns;
+ private Map<String, String> properties = Maps.newHashMap();
+ private Transform[] partitions = new Transform[0];
+ private Constraint[] constraints = new Constraint[0];
+
+ public Builder withColumns(Column[] columns) {
+ this.columns = columns;
+ return this;
+ }
+
+ public Builder withProperties(Map<String, String> properties) {
+ this.properties = properties;
+ return this;
+ }
+
+ public Builder withPartitions(Transform[] partitions) {
+ this.partitions = partitions;
+ return this;
+ }
+
+ public Builder withConstraints(Constraint[] constraints) {
+ this.constraints = constraints;
+ return this;
+ }
+
+ public TableInfo build() {
+ checkNotNull(columns, "columns should not be null");
+ return new TableInfo(this);
+ }
+ }
+}
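A sketch of assembling a `TableInfo` for the new single-argument `createTable` overload; the column and property values are illustrative:

```scala
import org.apache.spark.sql.connector.catalog.{Column, TableInfo}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.IntegerType

val info = new TableInfo.Builder()
  .withColumns(Array(Column.create("id", IntegerType)))
  .withPartitions(Array.empty[Transform])
  .withProperties(java.util.Map.of("owner", "spark"))
  .build()
// catalog.createTable(ident, info)
```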
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java
new file mode 100644
index 0000000000000..28791a9f3a58f
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/BaseConstraint.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.constraints;
+
+import java.util.StringJoiner;
+
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+abstract class BaseConstraint implements Constraint {
+
+ private final String name;
+ private final boolean enforced;
+ private final ValidationStatus validationStatus;
+ private final boolean rely;
+
+ protected BaseConstraint(
+ String name,
+ boolean enforced,
+ ValidationStatus validationStatus,
+ boolean rely) {
+ this.name = name;
+ this.enforced = enforced;
+ this.validationStatus = validationStatus;
+ this.rely = rely;
+ }
+
+ protected abstract String definition();
+
+ @Override
+ public String name() {
+ return name;
+ }
+
+ @Override
+ public boolean enforced() {
+ return enforced;
+ }
+
+ @Override
+ public ValidationStatus validationStatus() {
+ return validationStatus;
+ }
+
+ @Override
+ public boolean rely() {
+ return rely;
+ }
+
+ @Override
+ public String toDDL() {
+ return String.format(
+ "CONSTRAINT %s %s %s %s %s",
+ name,
+ definition(),
+ enforced ? "ENFORCED" : "NOT ENFORCED",
+ validationStatus,
+ rely ? "RELY" : "NORELY");
+ }
+
+ @Override
+ public String toString() {
+ return toDDL();
+ }
+
+ protected String toDDL(NamedReference[] columns) {
+ StringJoiner joiner = new StringJoiner(", ");
+
+ for (NamedReference column : columns) {
+ joiner.add(column.toString());
+ }
+
+ return joiner.toString();
+ }
+
+  abstract static class Builder<B extends Builder<B, C>, C extends Constraint> {
+ private final String name;
+ private boolean enforced = true;
+ private ValidationStatus validationStatus = ValidationStatus.UNVALIDATED;
+ private boolean rely = false;
+
+ Builder(String name) {
+ this.name = name;
+ }
+
+ protected abstract B self();
+
+ public abstract C build();
+
+ public String name() {
+ return name;
+ }
+
+ public B enforced(boolean enforced) {
+ this.enforced = enforced;
+ return self();
+ }
+
+ public boolean enforced() {
+ return enforced;
+ }
+
+ public B validationStatus(ValidationStatus validationStatus) {
+ if (validationStatus != null) {
+ this.validationStatus = validationStatus;
+ }
+ return self();
+ }
+
+ public ValidationStatus validationStatus() {
+ return validationStatus;
+ }
+
+ public B rely(boolean rely) {
+ this.rely = rely;
+ return self();
+ }
+
+ public boolean rely() {
+ return rely;
+ }
+ }
+}
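
As a concrete illustration of the DDL shape produced by toDDL() above: a primary key constraint (PrimaryKey is defined later in this patch) named pk_id over column id, built with enforced(false), validationStatus(VALID) and rely(true), renders as

    CONSTRAINT pk_id PRIMARY KEY (id) NOT ENFORCED VALID RELY
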
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java
new file mode 100644
index 0000000000000..ae005d946694a
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.constraints;
+
+import java.util.Map;
+import java.util.Objects;
+
+import org.apache.spark.SparkIllegalArgumentException;
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+
+/**
+ * A CHECK constraint.
+ *
+ * A CHECK constraint defines a condition each row in a table must satisfy. Connectors can define
+ * such constraints either in SQL (Spark SQL dialect) or using a {@link Predicate predicate} if the
+ * condition can be expressed using a supported expression. A CHECK constraint can reference one or
+ * more columns. Such a constraint is considered violated if its condition evaluates to
+ * {@code FALSE}, but not if it evaluates to {@code NULL}. The search condition must be
+ * deterministic and cannot contain subqueries or certain functions such as aggregates or UDFs.
+ *
+ * Spark supports enforced and not enforced CHECK constraints, allowing connectors to control
+ * whether data modifications that violate the constraint must fail. Each constraint is either
+ * valid (the existing data is guaranteed to satisfy the constraint), invalid (some records violate
+ * the constraint), or unvalidated (the validity is unknown). If the validity is unknown, Spark
+ * will check {@link #rely()} to see whether the constraint is believed to be true and can be used
+ * for query optimization.
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public class Check extends BaseConstraint {
+
+ private final String predicateSql;
+ private final Predicate predicate;
+
+ private Check(
+ String name,
+ String predicateSql,
+ Predicate predicate,
+ boolean enforced,
+ ValidationStatus validationStatus,
+ boolean rely) {
+ super(name, enforced, validationStatus, rely);
+ this.predicateSql = predicateSql;
+ this.predicate = predicate;
+ }
+
+ /**
+ * Returns the SQL representation of the search condition (Spark SQL dialect).
+ */
+ public String predicateSql() {
+ return predicateSql;
+ }
+
+ /**
+ * Returns the search condition.
+ */
+ public Predicate predicate() {
+ return predicate;
+ }
+
+ @Override
+ protected String definition() {
+ return String.format("CHECK (%s)", predicateSql != null ? predicateSql : predicate);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) return true;
+ if (other == null || getClass() != other.getClass()) return false;
+ Check that = (Check) other;
+ return Objects.equals(name(), that.name()) &&
+ Objects.equals(predicateSql, that.predicateSql) &&
+ Objects.equals(predicate, that.predicate) &&
+ enforced() == that.enforced() &&
+ Objects.equals(validationStatus(), that.validationStatus()) &&
+ rely() == that.rely();
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(name(), predicateSql, predicate, enforced(), validationStatus(), rely());
+ }
+
+  public static class Builder extends BaseConstraint.Builder<Builder, Check> {
+
+ private String predicateSql;
+ private Predicate predicate;
+
+ Builder(String name) {
+ super(name);
+ }
+
+ @Override
+ protected Builder self() {
+ return this;
+ }
+
+ public Builder predicateSql(String predicateSql) {
+ this.predicateSql = predicateSql;
+ return this;
+ }
+
+ public Builder predicate(Predicate predicate) {
+ this.predicate = predicate;
+ return this;
+ }
+
+ public Check build() {
+ if (predicateSql == null && predicate == null) {
+ throw new SparkIllegalArgumentException(
+ "INTERNAL_ERROR",
+ Map.of("message", "Predicate SQL and expression can't be both null in CHECK"));
+ }
+ return new Check(name(), predicateSql, predicate, enforced(), validationStatus(), rely());
+ }
+ }
+}
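
For illustration, a minimal sketch of building a CHECK constraint through the Constraint.check factory introduced in the next file; the constraint name, column and condition are made up, and Expressions, LiteralValue and Predicate come from the existing V2 expression API:

    import org.apache.spark.sql.connector.catalog.constraints.Check;
    import org.apache.spark.sql.connector.catalog.constraints.Constraint;
    import org.apache.spark.sql.connector.expressions.Expression;
    import org.apache.spark.sql.connector.expressions.Expressions;
    import org.apache.spark.sql.connector.expressions.LiteralValue;
    import org.apache.spark.sql.connector.expressions.filter.Predicate;
    import org.apache.spark.sql.types.DataTypes;

    // SQL-only form: the search condition is carried as Spark SQL text.
    Check bySql = Constraint.check("id_positive")
        .predicateSql("id > 0")
        .build();

    // Also attach the equivalent V2 predicate when the condition can be expressed that way.
    Predicate idPositive = new Predicate(">", new Expression[] {
        Expressions.column("id"), new LiteralValue<>(0, DataTypes.IntegerType)});
    Check byBoth = Constraint.check("id_positive")
        .predicateSql("id > 0")
        .predicate(idPositive)
        .build();
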
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java
new file mode 100644
index 0000000000000..c3a2cd73e9abe
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Constraint.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.constraints;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * A constraint that restricts states of data in a table.
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public interface Constraint {
+ /**
+ * Returns the name of this constraint.
+ */
+ String name();
+
+ /**
+ * Indicates whether this constraint is actively enforced. If enforced, data modifications
+ * that violate the constraint fail with a constraint violation error.
+ */
+ boolean enforced();
+
+ /**
+ * Indicates whether the existing data in the table satisfies this constraint. The constraint
+ * can be valid (the data is guaranteed to satisfy the constraint), invalid (some records violate
+ * the constraint), or unvalidated (the validity is unknown). The validation status is usually
+ * managed by the system and can't be modified by the user.
+ */
+ ValidationStatus validationStatus();
+
+ /**
+ * Indicates whether this constraint is assumed to hold true if the validity is unknown. Unlike
+ * the validation status, this flag is usually provided by the user as a hint to the system.
+ */
+ boolean rely();
+
+ /**
+ * Returns the definition of this constraint in the DDL format.
+ */
+ String toDDL();
+
+ /**
+ * Instantiates a builder for a CHECK constraint.
+ *
+ * @param name the constraint name
+ * @return a CHECK constraint builder
+ */
+ static Check.Builder check(String name) {
+ return new Check.Builder(name);
+ }
+
+ /**
+ * Instantiates a builder for a UNIQUE constraint.
+ *
+ * @param name the constraint name
+ * @param columns columns that comprise the unique key
+ * @return a UNIQUE constraint builder
+ */
+ static Unique.Builder unique(String name, NamedReference[] columns) {
+ return new Unique.Builder(name, columns);
+ }
+
+ /**
+ * Instantiates a builder for a PRIMARY KEY constraint.
+ *
+ * @param name the constraint name
+ * @param columns columns that comprise the primary key
+ * @return a PRIMARY KEY constraint builder
+ */
+ static PrimaryKey.Builder primaryKey(String name, NamedReference[] columns) {
+ return new PrimaryKey.Builder(name, columns);
+ }
+
+ /**
+ * Instantiates a builder for a FOREIGN KEY constraint.
+ *
+ * @param name the constraint name
+ * @param columns the referencing columns
+ * @param refTable the referenced table identifier
+ * @param refColumns the referenced columns in the referenced table
+ * @return a FOREIGN KEY constraint builder
+ */
+ static ForeignKey.Builder foreignKey(
+ String name,
+ NamedReference[] columns,
+ Identifier refTable,
+ NamedReference[] refColumns) {
+ return new ForeignKey.Builder(name, columns, refTable, refColumns);
+ }
+
+ /**
+ * An indicator of the validity of the constraint.
+ *
+ * A constraint may be validated independently of enforcement, meaning it can be validated
+ * without being actively enforced, or vice versa. A constraint can be valid (the data is
+ * guaranteed to satisfy the constraint), invalid (some records violate the constraint),
+ * or unvalidated (the validity is unknown).
+ */
+ enum ValidationStatus {
+ VALID, INVALID, UNVALIDATED
+ }
+}
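
For illustration, the other factory methods follow the same builder pattern; a sketch with made-up constraint and column names, using the existing Expressions helper for column references:

    import org.apache.spark.sql.connector.catalog.constraints.Constraint;
    import org.apache.spark.sql.connector.catalog.constraints.PrimaryKey;
    import org.apache.spark.sql.connector.catalog.constraints.Unique;
    import org.apache.spark.sql.connector.expressions.Expressions;
    import org.apache.spark.sql.connector.expressions.NamedReference;

    NamedReference[] id = new NamedReference[] { Expressions.column("id") };

    // Not enforced by Spark, but already validated and relied on for optimization.
    PrimaryKey pk = Constraint.primaryKey("pk_id", id)
        .enforced(false)
        .validationStatus(Constraint.ValidationStatus.VALID)
        .rely(true)
        .build();

    // Validity unknown (UNVALIDATED by default), but assumed to hold.
    Unique uq = Constraint.unique("uq_name", new NamedReference[] { Expressions.column("name") })
        .rely(true)
        .build();
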
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java
new file mode 100644
index 0000000000000..cb1a441688e99
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/ForeignKey.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.constraints;
+
+import java.util.Arrays;
+import java.util.Objects;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * A FOREIGN KEY constraint.
+ *
+ * A FOREIGN KEY constraint specifies one or more columns (referencing columns) in a table that
+ * refer to corresponding columns (referenced columns) in another table. The referenced columns
+ * must form a UNIQUE or PRIMARY KEY constraint in the referenced table. For this constraint to be
+ * satisfied, each row in the table must contain values in the referencing columns that exactly
+ * match values of a row in the referenced table.
+ *
+ * Spark doesn't enforce FOREIGN KEY constraints but leverages them for query optimization. Each
+ * constraint is either valid (the existing data is guaranteed to satisfy the constraint), invalid
+ * (some records violate the constraint), or unvalidated (the validity is unknown). If the validity
+ * is unknown, Spark will check {@link #rely()} to see whether the constraint is believed to be
+ * true and can be used for query optimization.
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public class ForeignKey extends BaseConstraint {
+
+ private final NamedReference[] columns;
+ private final Identifier refTable;
+ private final NamedReference[] refColumns;
+
+ ForeignKey(
+ String name,
+ NamedReference[] columns,
+ Identifier refTable,
+ NamedReference[] refColumns,
+ boolean enforced,
+ ValidationStatus validationStatus,
+ boolean rely) {
+ super(name, enforced, validationStatus, rely);
+ this.columns = columns;
+ this.refTable = refTable;
+ this.refColumns = refColumns;
+ }
+
+ /**
+ * Returns the referencing columns.
+ */
+ public NamedReference[] columns() {
+ return columns;
+ }
+
+ /**
+ * Returns the referenced table.
+ */
+ public Identifier referencedTable() {
+ return refTable;
+ }
+
+ /**
+ * Returns the referenced columns in the referenced table.
+ */
+ public NamedReference[] referencedColumns() {
+ return refColumns;
+ }
+
+ @Override
+ protected String definition() {
+ return String.format(
+ "FOREIGN KEY (%s) REFERENCES %s (%s)",
+ toDDL(columns),
+ refTable,
+ toDDL(refColumns));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) return true;
+ if (other == null || getClass() != other.getClass()) return false;
+ ForeignKey that = (ForeignKey) other;
+ return Objects.equals(name(), that.name()) &&
+ Arrays.equals(columns, that.columns) &&
+ Objects.equals(refTable, that.refTable) &&
+ Arrays.equals(refColumns, that.refColumns) &&
+ enforced() == that.enforced() &&
+ Objects.equals(validationStatus(), that.validationStatus()) &&
+ rely() == that.rely();
+ }
+
+ @Override
+ public int hashCode() {
+ int result = Objects.hash(name(), refTable, enforced(), validationStatus(), rely());
+ result = 31 * result + Arrays.hashCode(columns);
+ result = 31 * result + Arrays.hashCode(refColumns);
+ return result;
+ }
+
+  public static class Builder extends BaseConstraint.Builder<Builder, ForeignKey> {
+
+ private final NamedReference[] columns;
+ private final Identifier refTable;
+ private final NamedReference[] refColumns;
+
+ public Builder(
+ String name,
+ NamedReference[] columns,
+ Identifier refTable,
+ NamedReference[] refColumns) {
+ super(name);
+ this.columns = columns;
+ this.refTable = refTable;
+ this.refColumns = refColumns;
+ }
+
+ @Override
+ protected Builder self() {
+ return this;
+ }
+
+ @Override
+ public ForeignKey build() {
+ return new ForeignKey(
+ name(),
+ columns,
+ refTable,
+ refColumns,
+ enforced(),
+ validationStatus(),
+ rely());
+ }
+ }
+}
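
For illustration, a sketch of a FOREIGN KEY built through the Constraint.foreignKey factory; the table and column names are made up, and Identifier and Expressions are existing connector APIs:

    import org.apache.spark.sql.connector.catalog.Identifier;
    import org.apache.spark.sql.connector.catalog.constraints.Constraint;
    import org.apache.spark.sql.connector.catalog.constraints.ForeignKey;
    import org.apache.spark.sql.connector.expressions.Expressions;
    import org.apache.spark.sql.connector.expressions.NamedReference;

    ForeignKey fk = Constraint.foreignKey(
            "fk_customer",
            new NamedReference[] { Expressions.column("customer_id") },  // referencing columns
            Identifier.of(new String[] { "sales" }, "customers"),        // referenced table
            new NamedReference[] { Expressions.column("id") })           // referenced columns
        .enforced(false)
        .rely(true)
        .build();

    // fk.toDDL() ->
    //   CONSTRAINT fk_customer FOREIGN KEY (customer_id) REFERENCES sales.customers (id)
    //   NOT ENFORCED UNVALIDATED RELY
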
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java
new file mode 100644
index 0000000000000..31950b85bb8db
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/PrimaryKey.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.constraints;
+
+import java.util.Arrays;
+import java.util.Objects;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * A PRIMARY KEY constraint.
+ *
+ * A PRIMARY KEY constraint specifies one or more columns as a primary key. Such a constraint is
+ * satisfied if and only if no two rows in a table have the same non-null values in the primary
+ * key columns and none of the values in the specified column or columns are {@code NULL}.
+ * A table can have at most one primary key.
+ *
+ * Spark doesn't enforce PRIMARY KEY constraints but leverages them for query optimization. Each
+ * constraint is either valid (the existing data is guaranteed to satisfy the constraint), invalid
+ * (some records violate the constraint), or unvalidated (the validity is unknown). If the validity
+ * is unknown, Spark will check {@link #rely()} to see whether the constraint is believed to be
+ * true and can be used for query optimization.
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public class PrimaryKey extends BaseConstraint {
+
+ private final NamedReference[] columns;
+
+ PrimaryKey(
+ String name,
+ NamedReference[] columns,
+ boolean enforced,
+ ValidationStatus validationStatus,
+ boolean rely) {
+ super(name, enforced, validationStatus, rely);
+ this.columns = columns;
+ }
+
+ /**
+ * Returns the columns that comprise the primary key.
+ */
+ public NamedReference[] columns() {
+ return columns;
+ }
+
+ @Override
+ protected String definition() {
+ return String.format("PRIMARY KEY (%s)", toDDL(columns));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) return true;
+ if (other == null || getClass() != other.getClass()) return false;
+ PrimaryKey that = (PrimaryKey) other;
+ return Objects.equals(name(), that.name()) &&
+ Arrays.equals(columns, that.columns()) &&
+ enforced() == that.enforced() &&
+ Objects.equals(validationStatus(), that.validationStatus()) &&
+ rely() == that.rely();
+ }
+
+ @Override
+ public int hashCode() {
+ int result = Objects.hash(name(), enforced(), validationStatus(), rely());
+ result = 31 * result + Arrays.hashCode(columns);
+ return result;
+ }
+
+  public static class Builder extends BaseConstraint.Builder<Builder, PrimaryKey> {
+
+ private final NamedReference[] columns;
+
+ Builder(String name, NamedReference[] columns) {
+ super(name);
+ this.columns = columns;
+ }
+
+ @Override
+ protected Builder self() {
+ return this;
+ }
+
+ @Override
+ public PrimaryKey build() {
+ return new PrimaryKey(name(), columns, enforced(), validationStatus(), rely());
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java
new file mode 100644
index 0000000000000..d983ef656297e
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.constraints;
+
+import java.util.Arrays;
+import java.util.Objects;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * A UNIQUE constraint.
+ *
+ * A UNIQUE constraint specifies one or more columns as unique columns. Such a constraint is
+ * satisfied if and only if no two rows in a table have the same non-null values in those columns.
+ *
+ * Spark doesn't enforce UNIQUE constraints but leverages them for query optimization. Each
+ * constraint is either valid (the existing data is guaranteed to satisfy the constraint), invalid
+ * (some records violate the constraint), or unvalidated (the validity is unknown). If the validity
+ * is unknown, Spark will check {@link #rely()} to see whether the constraint is believed to be
+ * true and can be used for query optimization.
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public class Unique extends BaseConstraint {
+
+ private final NamedReference[] columns;
+
+ private Unique(
+ String name,
+ NamedReference[] columns,
+ boolean enforced,
+ ValidationStatus validationStatus,
+ boolean rely) {
+ super(name, enforced, validationStatus, rely);
+ this.columns = columns;
+ }
+
+ /**
+ * Returns the columns that comprise the unique key.
+ */
+ public NamedReference[] columns() {
+ return columns;
+ }
+
+ @Override
+ protected String definition() {
+ return String.format("UNIQUE (%s)", toDDL(columns));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) return true;
+ if (other == null || getClass() != other.getClass()) return false;
+ Unique that = (Unique) other;
+ return Objects.equals(name(), that.name()) &&
+ Arrays.equals(columns, that.columns()) &&
+ enforced() == that.enforced() &&
+ Objects.equals(validationStatus(), that.validationStatus()) &&
+ rely() == that.rely();
+ }
+
+ @Override
+ public int hashCode() {
+ int result = Objects.hash(name(), enforced(), validationStatus(), rely());
+ result = 31 * result + Arrays.hashCode(columns);
+ return result;
+ }
+
+  public static class Builder extends BaseConstraint.Builder<Builder, Unique> {
+
+ private final NamedReference[] columns;
+
+ Builder(String name, NamedReference[] columns) {
+ super(name);
+ this.columns = columns;
+ }
+
+ @Override
+ protected Builder self() {
+ return this;
+ }
+
+ public Unique build() {
+ return new Unique(name(), columns, enforced(), validationStatus(), rely());
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
index 18c76833c5879..3d837be366f7f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
@@ -20,6 +20,8 @@
import javax.annotation.Nullable;
import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.catalog.DefaultValue;
+import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.internal.connector.ProcedureParameterImpl;
import org.apache.spark.sql.types.DataType;
@@ -68,7 +70,7 @@ static Builder in(String name, DataType dataType) {
* null if not provided.
*/
@Nullable
- String defaultValueExpression();
+ DefaultValue defaultValue();
/**
* Returns the comment of this parameter or null if not provided.
@@ -89,7 +91,7 @@ class Builder {
private final Mode mode;
private final String name;
private final DataType dataType;
- private String defaultValueExpression;
+ private DefaultValue defaultValue;
private String comment;
private Builder(Mode mode, String name, DataType dataType) {
@@ -99,10 +101,26 @@ private Builder(Mode mode, String name, DataType dataType) {
}
/**
- * Sets the default value expression of the parameter.
+ * Sets the default value of the parameter using SQL.
*/
- public Builder defaultValue(String defaultValueExpression) {
- this.defaultValueExpression = defaultValueExpression;
+ public Builder defaultValue(String sql) {
+ this.defaultValue = new DefaultValue(sql);
+ return this;
+ }
+
+ /**
+ * Sets the default value of the parameter using an expression.
+ */
+ public Builder defaultValue(Expression expression) {
+ this.defaultValue = new DefaultValue(expression);
+ return this;
+ }
+
+ /**
+ * Sets the default value of the parameter.
+ */
+ public Builder defaultValue(DefaultValue defaultValue) {
+ this.defaultValue = defaultValue;
return this;
}
@@ -118,7 +136,7 @@ public Builder comment(String comment) {
* Builds the stored procedure parameter.
*/
public ProcedureParameter build() {
- return new ProcedureParameterImpl(mode, name, dataType, defaultValueExpression, comment);
+ return new ProcedureParameterImpl(mode, name, dataType, defaultValue, comment);
}
}
}
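
For illustration, how the new overloads look from a procedure implementation; the parameter name and default values are made up, and LiteralValue is the existing V2 literal expression:

    import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter;
    import org.apache.spark.sql.connector.expressions.LiteralValue;
    import org.apache.spark.sql.types.DataTypes;

    // Default supplied as Spark SQL text.
    ProcedureParameter p1 = ProcedureParameter.in("retries", DataTypes.IntegerType)
        .defaultValue("3")
        .build();

    // Default supplied as a connector expression.
    ProcedureParameter p2 = ProcedureParameter.in("retries", DataTypes.IntegerType)
        .defaultValue(new LiteralValue<>(3, DataTypes.IntegerType))
        .build();
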
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index cb132ab11326d..4298f31227500 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -80,7 +80,7 @@ public String build(Expression expr) {
} else if (expr instanceof Cast cast) {
return visitCast(build(cast.expression()), cast.expressionDataType(), cast.dataType());
} else if (expr instanceof Extract extract) {
- return visitExtract(extract.field(), build(extract.source()));
+ return visitExtract(extract);
} else if (expr instanceof SortOrder sortOrder) {
return visitSortOrder(
build(sortOrder.expression()), sortOrder.direction(), sortOrder.nullOrdering());
@@ -119,7 +119,7 @@ yield visitBinaryArithmetic(
"RADIANS", "SIGN", "WIDTH_BUCKET", "SUBSTRING", "UPPER", "LOWER", "TRANSLATE",
"DATE_ADD", "DATE_DIFF", "TRUNC", "AES_ENCRYPT", "AES_DECRYPT", "SHA1", "SHA2", "MD5",
"CRC32", "BIT_LENGTH", "CHAR_LENGTH", "CONCAT", "RPAD", "LPAD" ->
- visitSQLFunction(name, expressionsToStringArray(e.children()));
+ visitSQLFunction(name, e.children());
case "CASE_WHEN" -> visitCaseWhen(expressionsToStringArray(e.children()));
case "TRIM" -> visitTrim("BOTH", expressionsToStringArray(e.children()));
case "LTRIM" -> visitTrim("LEADING", expressionsToStringArray(e.children()));
@@ -147,8 +147,7 @@ yield visitBinaryArithmetic(
expressionsToStringArray(avg.children()));
} else if (expr instanceof GeneralAggregateFunc f) {
if (f.orderingWithinGroups().length == 0) {
- return visitAggregateFunction(f.name(), f.isDistinct(),
- expressionsToStringArray(f.children()));
+ return visitAggregateFunction(f.name(), f.isDistinct(), f.children());
} else {
return visitInverseDistributionFunction(
f.name(),
@@ -273,12 +272,20 @@ protected String visitCaseWhen(String[] children) {
return sb.toString();
}
+ protected String visitSQLFunction(String funcName, Expression[] inputs) {
+ return visitSQLFunction(funcName, expressionsToStringArray(inputs));
+ }
+
protected String visitSQLFunction(String funcName, String[] inputs) {
return joinArrayToString(inputs, ", ", funcName + "(", ")");
}
protected String visitAggregateFunction(
- String funcName, boolean isDistinct, String[] inputs) {
+ String funcName, boolean isDistinct, Expression[] inputs) {
+ return visitAggregateFunction(funcName, isDistinct, expressionsToStringArray(inputs));
+ }
+
+ protected String visitAggregateFunction(String funcName, boolean isDistinct, String[] inputs) {
if (isDistinct) {
return joinArrayToString(inputs, ", ", funcName + "(DISTINCT ", ")");
} else {
@@ -333,6 +340,10 @@ protected String visitTrim(String direction, String[] inputs) {
}
}
+ protected String visitExtract(Extract extract) {
+ return visitExtract(extract.field(), build(extract.source()));
+ }
+
protected String visitExtract(String field, String source) {
return "EXTRACT(" + field + " FROM " + source + ")";
}
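
For illustration, the point of passing the original expressions into these hooks is that dialect subclasses can inspect them instead of only their pre-rendered SQL strings; a hypothetical override of the new visitExtract(Extract) hook (the dialect, field name and target function are made up):

    import org.apache.spark.sql.connector.expressions.Extract;
    import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder;

    class MyDialectSQLBuilder extends V2ExpressionSQLBuilder {
      @Override
      protected String visitExtract(Extract extract) {
        // Hypothetical dialect that spells day-of-week extraction as a function call,
        // falling back to the default EXTRACT(... FROM ...) rendering otherwise.
        if ("DAY_OF_WEEK".equals(extract.field())) {
          return "DAYOFWEEK(" + build(extract.source()) + ")";
        }
        return super.visitExtract(extract);
      }
    }
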
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index fab65251ed51b..ef4308beafe86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable
@@ -71,6 +71,7 @@ object CatalystTypeConverters {
case _: StringType => StringConverter
case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
case DateType => DateConverter
+ case _: TimeType => TimeConverter
case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter
case TimestampType => TimestampConverter
case TimestampNTZType => TimestampNTZConverter
@@ -372,6 +373,18 @@ object CatalystTypeConverters {
DateTimeUtils.daysToLocalDate(row.getInt(column))
}
+ private object TimeConverter extends CatalystTypeConverter[LocalTime, LocalTime, Any] {
+ override def toCatalystImpl(scalaValue: LocalTime): Long = {
+ DateTimeUtils.localTimeToMicros(scalaValue)
+ }
+ override def toScala(catalystValue: Any): LocalTime = {
+ if (catalystValue == null) null
+ else DateTimeUtils.microsToLocalTime(catalystValue.asInstanceOf[Long])
+ }
+ override def toScalaImpl(row: InternalRow, column: Int): LocalTime =
+ DateTimeUtils.microsToLocalTime(row.getLong(column))
+ }
+
private object TimestampConverter extends CatalystTypeConverter[Any, Timestamp, Any] {
override def toCatalystImpl(scalaValue: Any): Long = scalaValue match {
case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t)
@@ -558,6 +571,7 @@ object CatalystTypeConverters {
case c: Char => StringConverter.toCatalyst(c.toString)
case d: Date => DateConverter.toCatalyst(d)
case ld: LocalDate => LocalDateConverter.toCatalyst(ld)
+ case t: LocalTime => TimeConverter.toCatalyst(t)
case t: Timestamp => TimestampConverter.toCatalyst(t)
case i: Instant => InstantConverter.toCatalyst(i)
case l: LocalDateTime => TimestampNTZConverter.toCatalyst(l)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DataSourceOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DataSourceOptions.scala
index 5348d1054d5d4..aa5e3c1de13f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DataSourceOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DataSourceOptions.scala
@@ -17,6 +17,11 @@
package org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
/**
* Interface defines the following methods for a data source:
* - register a new option name
@@ -64,3 +69,46 @@ trait DataSourceOptions {
*/
def getAlternativeOption(name: String): Option[String] = validOptions.get(name).flatten
}
+
+object DataSourceOptions {
+  // The common option name for all data sources that support single-variant-column parsing mode.
+  // The option takes a column name and specifies that the entire record should be stored
+ // as a single VARIANT type column in the table with the given column name.
+ // E.g. spark.read.format("").option("singleVariantColumn", "colName")
+ val SINGLE_VARIANT_COLUMN = "singleVariantColumn"
+  // The common option name for all data sources that support corrupt records. In case of a parsing
+ // error, the record will be stored as a string in the column with the given name.
+ // Theoretically, the behavior of this option is not affected by the parsing mode
+ // (PERMISSIVE/FAILFAST/DROPMALFORMED). However, the corrupt record is only visible to the user
+ // when in PERMISSIVE mode, because the queries will fail in FAILFAST mode, or the row containing
+ // the corrupt record will be dropped in DROPMALFORMED mode.
+ val COLUMN_NAME_OF_CORRUPT_RECORD = "columnNameOfCorruptRecord"
+
+ // When `singleVariantColumn` is enabled and there is a user-specified schema, the schema must
+ // either be a variant field, or a variant field plus a corrupt column field.
+ def validateSingleVariantColumn(
+ options: CaseInsensitiveMap[String],
+ userSpecifiedSchema: Option[StructType]): Unit = {
+ (options.get(SINGLE_VARIANT_COLUMN), userSpecifiedSchema) match {
+ case (Some(variantColumnName), Some(schema)) =>
+ var valid = schema.fields.exists { f =>
+ f.dataType.isInstanceOf[VariantType] && f.name == variantColumnName && f.nullable
+ }
+ schema.length match {
+ case 1 =>
+ case 2 =>
+ val corruptRecordColumnName = options.getOrElse(
+ COLUMN_NAME_OF_CORRUPT_RECORD, SQLConf.get.columnNameOfCorruptRecord)
+ valid = valid && corruptRecordColumnName != variantColumnName
+ valid = valid && schema.fields.exists { f =>
+ f.dataType.isInstanceOf[StringType] && f.name == corruptRecordColumnName && f.nullable
+ }
+ case _ => valid = false
+ }
+ if (!valid) {
+ throw QueryCompilationErrors.invalidSingleVariantColumn(schema)
+ }
+ case _ =>
+ }
+ }
+}
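
For illustration, the two schema shapes this check accepts when reading with singleVariantColumn; the format, path and column names are made up, and spark is assumed to be an existing SparkSession:

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;

    // No user schema: the whole record lands in a single VARIANT column named "v".
    Dataset<Row> df = spark.read()
        .format("json")
        .option("singleVariantColumn", "v")
        .load("/tmp/events");  // hypothetical path

    // User-specified schema: either the variant column alone, or the variant column plus a
    // nullable STRING corrupt-record column (default name: _corrupt_record).
    Dataset<Row> withCorrupt = spark.read()
        .format("json")
        .schema("v VARIANT, _corrupt_record STRING")
        .option("singleVariantColumn", "v")
        .load("/tmp/events");
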
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index fc477d1bc5ef5..9b22f28ed12da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
@@ -156,6 +156,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}
+ def createDeserializerForLocalTime(path: Expression): Expression = {
+ StaticInvoke(
+ DateTimeUtils.getClass,
+ ObjectType(classOf[java.time.LocalTime]),
+ "microsToLocalTime",
+ path :: Nil,
+ returnNullable = false)
+ }
+
def createDeserializerForJavaBigDecimal(
path: Expression,
returnNullable: Boolean): Expression = {
@@ -270,6 +279,8 @@ object DeserializerBuildHelper {
enc: AgnosticEncoder[_],
path: Expression,
walkedTypePath: WalkedTypePath): Expression = enc match {
+ case ae: AgnosticExpressionPathEncoder[_] =>
+ ae.fromCatalyst(path)
case _ if isNativeEncoder(enc) =>
path
case _: BoxedLeafEncoder[_, _] =>
@@ -312,6 +323,8 @@ object DeserializerBuildHelper {
createDeserializerForInstant(path)
case LocalDateTimeEncoder =>
createDeserializerForLocalDateTime(path)
+ case LocalTimeEncoder =>
+ createDeserializerForLocalTime(path)
case UDTEncoder(udt, udtClass) =>
val obj = NewInstance(udtClass, Nil, ObjectType(udtClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
@@ -447,13 +460,13 @@ object DeserializerBuildHelper {
val result = InitializeJavaBean(newInstance, setters.toMap)
exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)
- case TransformingEncoder(tag, _, codec) if codec == JavaSerializationCodec =>
+ case TransformingEncoder(tag, _, codec, _) if codec == JavaSerializationCodec =>
DecodeUsingSerializer(path, tag, kryo = false)
- case TransformingEncoder(tag, _, codec) if codec == KryoSerializationCodec =>
+ case TransformingEncoder(tag, _, codec, _) if codec == KryoSerializationCodec =>
DecodeUsingSerializer(path, tag, kryo = true)
- case TransformingEncoder(tag, encoder, provider) =>
+ case TransformingEncoder(tag, encoder, provider, _) =>
Invoke(
Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
"decode",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 089d463ecacbb..c8bf1f5237997 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -21,8 +21,8 @@ import scala.language.existentials
import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.objects._
@@ -99,6 +99,15 @@ object SerializerBuildHelper {
returnNullable = false)
}
+ def createSerializerForLocalTime(inputObject: Expression): Expression = {
+ StaticInvoke(
+ DateTimeUtils.getClass,
+ TimeType(),
+ "localTimeToMicros",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
def createSerializerForScalaEnum(inputObject: Expression): Expression = {
createSerializerForString(
Invoke(
@@ -306,6 +315,7 @@ object SerializerBuildHelper {
* by encoder `enc`.
*/
private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
+ case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input)
case _ if isNativeEncoder(enc) => input
case BoxedBooleanEncoder => createSerializerForBoolean(input)
case BoxedByteEncoder => createSerializerForByte(input)
@@ -333,6 +343,7 @@ object SerializerBuildHelper {
case TimestampEncoder(false) => createSerializerForSqlTimestamp(input)
case InstantEncoder(false) => createSerializerForJavaInstant(input)
case LocalDateTimeEncoder => createSerializerForLocalDateTime(input)
+ case LocalTimeEncoder => createSerializerForLocalTime(input)
case UDTEncoder(udt, udtClass) => createSerializerForUserDefinedType(input, udt, udtClass)
case OptionEncoder(valueEnc) =>
createSerializer(valueEnc, UnwrapOption(externalDataTypeFor(valueEnc), input))
@@ -418,18 +429,21 @@ object SerializerBuildHelper {
}
createSerializerForObject(input, serializedFields)
- case TransformingEncoder(_, _, codec) if codec == JavaSerializationCodec =>
+ case TransformingEncoder(_, _, codec, _) if codec == JavaSerializationCodec =>
EncodeUsingSerializer(input, kryo = false)
- case TransformingEncoder(_, _, codec) if codec == KryoSerializationCodec =>
+ case TransformingEncoder(_, _, codec, _) if codec == KryoSerializationCodec =>
EncodeUsingSerializer(input, kryo = true)
- case TransformingEncoder(_, encoder, codecProvider) =>
+ case TransformingEncoder(_, encoder, codecProvider, _) =>
val encoded = Invoke(
Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])),
"encode",
externalDataTypeFor(encoder),
- input :: Nil)
+ input :: Nil,
+ propagateNull = input.nullable,
+ returnNullable = input.nullable
+ )
createSerializer(encoder, encoded)
}
@@ -486,6 +500,7 @@ object SerializerBuildHelper {
nullable: Boolean): Expression => Expression = { input =>
val expected = enc match {
case OptionEncoder(_) => lenientExternalDataTypeFor(enc)
+ case TransformingEncoder(_, transformed, _, _) => lenientExternalDataTypeFor(transformed)
case _ => enc.dataType
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 81e8ac02000e5..fc895d60fad9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -158,6 +158,7 @@ case class AnalysisContext(
referredTempVariableNames: Seq[Seq[String]] = Seq.empty,
outerPlan: Option[LogicalPlan] = None,
isExecuteImmediate: Boolean = false,
+ collation: Option[String] = None,
/**
* This is a bridge state between this fixed-point [[Analyzer]] and a single-pass [[Resolver]].
@@ -213,7 +214,8 @@ object AnalysisContext {
viewDesc.viewReferredTempViewNames,
mutable.Set(viewDesc.viewReferredTempFunctionNames: _*),
viewDesc.viewReferredTempVariableNames,
- isExecuteImmediate = originContext.isExecuteImmediate)
+ isExecuteImmediate = originContext.isExecuteImmediate,
+ collation = viewDesc.collation)
set(context)
try f finally { set(originContext) }
}
@@ -336,7 +338,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
TypeCoercion.typeCoercionRules
}
- override def batches: Seq[Batch] = Seq(
+ private def earlyBatches: Seq[Batch] = Seq(
Batch("Substitution", fixedPoint,
// This rule optimizes `UpdateFields` expression chains so looks more like optimization rule.
// However, when manipulating deeply nested schema, `UpdateFields` expression tree could be
@@ -346,7 +348,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
CTESubstitution,
WindowsSubstitution,
EliminateUnions,
- SubstituteUnresolvedOrdinals,
EliminateLazyExpression),
Batch("Disable Hints", Once,
new ResolveHints.DisableHints),
@@ -357,7 +358,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
Batch("Simple Sanity Check", Once,
LookupFunctions),
Batch("Keep Legacy Outputs", Once,
- KeepLegacyOutputs),
+ KeepLegacyOutputs)
+ )
+
+ override def batches: Seq[Batch] = earlyBatches ++ Seq(
Batch("Resolution", fixedPoint,
new ResolveCatalogs(catalogManager) ::
ResolveInsertInto ::
@@ -387,10 +391,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveTableSpec ::
ValidateAndStripPipeExpressions ::
ResolveSQLFunctions ::
+ ResolveSQLTableFunctions ::
ResolveAliases ::
ResolveSubquery ::
ResolveSubqueryColumnAliases ::
- ResolveDDLCommandStringTypes ::
+ ApplyDefaultCollationToStringType ::
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalAndUsingJoin ::
@@ -408,7 +413,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveTimeZone ::
ResolveRandomSeed ::
ResolveBinaryArithmetic ::
- ResolveIdentifierClause ::
+ new ResolveIdentifierClause(earlyBatches) ::
ResolveUnion ::
ResolveRowLevelCommandAssignments ::
MoveParameterizedQueriesDown ::
@@ -448,7 +453,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
Batch("DML rewrite", fixedPoint,
RewriteDeleteFromTable,
RewriteUpdateTable,
- RewriteMergeIntoTable),
+ RewriteMergeIntoTable,
+ // Ensures columns of an output table are correctly resolved from the data in a logical plan.
+ ResolveOutputRelation),
Batch("Subquery", Once,
UpdateOuterReferences),
Batch("Cleanup", fixedPoint,
@@ -456,7 +463,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
Batch("HandleSpecialCommand", Once,
HandleSpecialCommand),
Batch("Remove watermark for batch query", Once,
- EliminateEventTimeWatermark)
+ EliminateEventTimeWatermark),
+ Batch("ResolveUnresolvedHaving", Once, ResolveUnresolvedHaving)
)
/**
@@ -520,7 +528,17 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
* Replaces [[UnresolvedAlias]]s with concrete aliases.
*/
object ResolveAliases extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan =
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val collatedPlan =
+ if (conf.getConf(SQLConf.RUN_COLLATION_TYPE_CASTS_BEFORE_ALIAS_ASSIGNMENT)) {
+ CollationTypeCasts(plan)
+ } else {
+ plan
+ }
+ doApply(collatedPlan)
+ }
+
+ private def doApply(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ALIAS), ruleId) {
case Aggregate(groups, aggs, child, _)
if child.resolved && AliasResolution.hasUnresolvedAlias(aggs) =>
@@ -556,6 +574,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
if c.child.resolved && AliasResolution.hasUnresolvedAlias(c.metrics) =>
c.copy(metrics = AliasResolution.assignAliases(c.metrics))
}
+ }
}
object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
@@ -982,25 +1001,30 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
object AddMetadataColumns extends Rule[LogicalPlan] {
import org.apache.spark.sql.catalyst.util._
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDownWithPruning(
- AlwaysProcess.fn, ruleId) {
- case hint: UnresolvedHint => hint
- // Add metadata output to all node types
- case node if node.children.nonEmpty && node.resolved && hasMetadataCol(node) =>
- val inputAttrs = AttributeSet(node.children.flatMap(_.output))
- val metaCols = getMetadataAttributes(node).filterNot(inputAttrs.contains)
- if (metaCols.isEmpty) {
- node
- } else {
- val newNode = node.mapChildren(addMetadataCol(_, metaCols.map(_.exprId).toSet))
- // We should not change the output schema of the plan. We should project away the extra
- // metadata columns if necessary.
- if (newNode.sameOutput(node)) {
- newNode
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val onlyUniqueAndNecessaryMetadataColumns =
+ conf.getConf(SQLConf.ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS)
+ plan.resolveOperatorsDownWithPruning(AlwaysProcess.fn, ruleId) {
+ case hint: UnresolvedHint => hint
+ // Add metadata output to all node types
+ case node if node.children.nonEmpty && node.resolved && hasMetadataCol(node) =>
+ val inputAttrs = AttributeSet(node.children.flatMap(_.output))
+ val metaCols = getMetadataAttributes(node).filterNot(inputAttrs.contains)
+ if (metaCols.isEmpty) {
+ node
} else {
- Project(node.output, newNode)
+ val newNode = node.mapChildren(
+ addMetadataCol(_, metaCols.map(_.exprId).toSet, onlyUniqueAndNecessaryMetadataColumns)
+ )
+ // We should not change the output schema of the plan. We should project away the extra
+ // metadata columns if necessary.
+ if (newNode.sameOutput(node)) {
+ newNode
+ } else {
+ Project(node.output, newNode)
+ }
}
- }
+ }
}
private def getMetadataAttributes(plan: LogicalPlan): Seq[Attribute] = {
@@ -1028,18 +1052,32 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
private def addMetadataCol(
plan: LogicalPlan,
- requiredAttrIds: Set[ExprId]): LogicalPlan = plan match {
+ requiredAttrIds: Set[ExprId],
+ onlyUniqueAndNecessaryMetadataColumns: Boolean = true): LogicalPlan = plan match {
case s: ExposesMetadataColumns if s.metadataOutput.exists( a =>
requiredAttrIds.contains(a.exprId)) =>
s.withMetadataColumns()
case p: Project if p.metadataOutput.exists(a => requiredAttrIds.contains(a.exprId)) =>
+ val uniqueMetadataColumns = if (onlyUniqueAndNecessaryMetadataColumns) {
+ val actualRequiredExprIds = new util.HashSet[ExprId](requiredAttrIds.asJava)
+ p.projectList.foreach(ne => actualRequiredExprIds.remove(ne.exprId))
+ p.metadataOutput.filter(attr => actualRequiredExprIds.contains(attr.exprId))
+ } else {
+ p.metadataOutput
+ }
+
val newProj = p.copy(
// Do not leak the qualified-access-only restriction to normal plan outputs.
- projectList = p.projectList ++ p.metadataOutput.map(_.markAsAllowAnyAccess()),
- child = addMetadataCol(p.child, requiredAttrIds))
+ projectList = p.projectList ++ uniqueMetadataColumns.map(_.markAsAllowAnyAccess()),
+ child = addMetadataCol(p.child, requiredAttrIds, onlyUniqueAndNecessaryMetadataColumns)
+ )
newProj.copyTagsFrom(p)
newProj
- case _ => plan.withNewChildren(plan.children.map(addMetadataCol(_, requiredAttrIds)))
+ case _ =>
+ plan.withNewChildren(
+ plan.children
+ .map(addMetadataCol(_, requiredAttrIds, onlyUniqueAndNecessaryMetadataColumns))
+ )
}
}
@@ -1188,8 +1226,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// We put the synchronously resolved relation into the [[AnalyzerBridgeState]] for
// it to be later reused by the single-pass [[Resolver]] to avoid resolving the relation
// metadata twice.
- AnalysisContext.get.getSinglePassResolverBridgeState.map { bridgeState =>
- bridgeState.relationsWithResolvedMetadata.put(unresolvedRelation, relation)
+ AnalysisContext.get.getSinglePassResolverBridgeState.foreach { bridgeState =>
+ bridgeState.addUnresolvedRelation(unresolvedRelation, relation)
}
relation
}
@@ -1796,7 +1834,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
* Returns true if `exprs` contains a [[Star]].
*/
def containsStar(exprs: Seq[Expression]): Boolean =
- exprs.exists(_.collect { case _: Star => true }.nonEmpty)
+ exprs.exists(_.collectFirst { case _: Star => true }.nonEmpty)
private def extractStar(exprs: Seq[Expression]): Seq[Star] =
exprs.flatMap(_.collect { case s: Star => s })
@@ -1936,24 +1974,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
withPosition(ordinal) {
if (index > 0 && index <= aggs.size) {
val ordinalExpr = aggs(index - 1)
+
if (ordinalExpr.exists(_.isInstanceOf[AggregateExpression])) {
throw QueryCompilationErrors.groupByPositionRefersToAggregateFunctionError(
index, ordinalExpr)
- } else {
- trimAliases(ordinalExpr) match {
- // HACK ALERT: If the ordinal expression is also an integer literal, don't use it
- // but still keep the ordinal literal. The reason is we may repeatedly
- // analyze the plan. Using a different integer literal may lead to
- // a repeat GROUP BY ordinal resolution which is wrong. GROUP BY
- // constant is meaningless so whatever value does not matter here.
- // TODO: (SPARK-45932) GROUP BY ordinal should pull out grouping expressions to
- // a Project, then the resolved ordinal expression is always
- // `AttributeReference`.
- case Literal(_: Int, IntegerType) =>
- Literal(index)
- case _ => ordinalExpr
- }
}
+
+ ordinalExpr
} else {
throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size)
}
@@ -2655,6 +2682,93 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}
+ /*
+ * This rule resolves SQL table functions.
+ */
+ object ResolveSQLTableFunctions extends Rule[LogicalPlan] with AliasHelper {
+
+ /**
+ * Check if a subquery plan is subject to the COUNT bug that can cause wrong results.
+ * A lateral correlation can only be removed if the lateral subquery is not subject to
+ * the COUNT bug. Currently only lateral correlation can handle it correctly.
+ */
+ private def hasCountBug(sub: LogicalPlan): Boolean = sub.find {
+ // The COUNT bug occurs when there is an Aggregate that satisfies all the following
+ // conditions:
+ // 1) is on the correlation path
+ // 2) has non-empty group by expressions
+ // 3) has one or more output columns that evaluate to non-null values with empty input.
+ // E.g: COUNT(empty row) = 0.
+ // For simplicity, we use a stricter criteria (1 and 2 only) to determine if a query
+ // is subject to the COUNT bug.
+ case a: Aggregate if a.groupingExpressions.nonEmpty => hasOuterReferences(a.child)
+ case _ => false
+ }.nonEmpty
+
+ /**
+ * Rewrite a resolved SQL table function plan by removing unnecessary lateral joins:
+ * Before:
+ * LateralJoin lateral-subquery [a], Inner
+ * : +- Project [c1, c2]
+ * : +- Filter [outer(a) == c1]
+ * : +- Relation [c1, c2]
+ * +- Project [1 AS a]
+ * +- OneRowRelation
+ * After:
+ * Project [c1, c2]
+ * +- Filter [1 == c1] <---- Replaced outer(a)
+ * +- Relation [c1, c2]
+ */
+ private def rewrite(plan: LogicalPlan): LogicalPlan = {
+ (plan transformUp {
+ case j @ LateralJoin(Project(aliases, _: OneRowRelation), sub: LateralSubquery, Inner, None)
+ if j.resolved && aliases.forall(_.deterministic) =>
+ val attrMap = AttributeMap(aliases.collect { case a: Alias => a.toAttribute -> a.child })
+ val newPlan = sub.plan.transformAllExpressionsWithPruning(
+ _.containsPattern(OUTER_REFERENCE)) {
+ // Avoid replacing outer references that do not belong to the current outer plan.
+ // This can happen if the child of an alias also contains outer references (nested
+ // table function references). E.g:
+ // LateralJoin
+ // : +- Filter [outer(a) == x]
+ // : +- Relation [x, y]
+ // +- Project [outer(c) AS a]
+ // +- OneRowRelation
+ case OuterReference(a: Attribute) if attrMap.contains(a) => attrMap(a) match {
+ case ne: NamedExpression => ne
+ case o => Alias(o, a.name)(exprId = a.exprId, qualifier = a.qualifier)
+ }
+ }
+ // Keep the original lateral join if the new plan is subject to the count bug.
+ if (hasCountBug(newPlan)) j else newPlan
+ }).transformWithPruning(_.containsPattern(ALIAS)) {
+ // As a result of the above rewriting, we may end up introducing nested Aliases (i.e.,
+ // Aliases defined inside Aliases). This is problematic for the plan canonicalization as
+ // it doesn't expect nested Aliases. Therefore, here we remove non-top level Aliases.
+ case node => node.mapExpressions(trimNonTopLevelAliases)
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+ _.containsPattern(SQL_TABLE_FUNCTION)) {
+ case SQLTableFunction(name, function, inputs, output) =>
+ // Resolve the SQL table function plan using its function context.
+ val conf = new SQLConf()
+ function.getSQLConfigs.foreach { case (k, v) => conf.settings.put(k, v) }
+ val resolved = SQLConf.withExistingConf(conf) {
+ val plan = v1SessionCatalog.makeSQLTableFunctionPlan(name, function, inputs, output)
+ SQLFunctionContext.withSQLFunction {
+ executeSameContext(plan)
+ }
+ }
+ // Remove unnecessary lateral joins that are used to resolve the SQL function.
+ val newPlan = rewrite(resolved)
+ // Fail the analysis eagerly if a SQL table function cannot be resolved using its input.
+ SimpleAnalyzer.checkAnalysis(newPlan)
+ newPlan
+ }
+ }
+
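To make the intent of the ResolveSQLTableFunctions rule added above concrete, here is a hedged sketch of a SQL table function call; the RETURNS TABLE DDL syntax and the function name are assumptions, not taken from this diff, and only the plan shapes shown in the rule's scaladoc are authoritative:

// Assumes a `spark` session in scope.
spark.sql(
  "CREATE OR REPLACE TEMPORARY FUNCTION square_tf(x INT) " +
    "RETURNS TABLE (x INT, sq INT) RETURN SELECT x, x * x")

// Conceptually the call below is first planned as
//   LateralJoin(Project(3 AS x, OneRowRelation), <body with outer(x)>)
// and the rule then inlines the literal argument, leaving a plain Project over the body,
// unless the body is subject to the COUNT bug described in hasCountBug.
spark.sql("SELECT * FROM square_tf(3)").show()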
/**
* Turns projections that contain aggregate expressions into aggregations.
*/
@@ -2692,7 +2806,17 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
* and group by expressions from them.
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val collatedPlan =
+ if (conf.getConf(SQLConf.RUN_COLLATION_TYPE_CASTS_BEFORE_ALIAS_ASSIGNMENT)) {
+ CollationTypeCasts(plan)
+ } else {
+ plan
+ }
+ doApply(collatedPlan)
+ }
+
+ def doApply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved && cond.resolved =>
resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => {
@@ -2768,7 +2892,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
expr match {
case ae: AggregateExpression =>
val cleaned = trimTempResolvedColumn(ae)
- val alias = Alias(cleaned, cleaned.toString)()
+ val alias = Alias(cleaned, toPrettySQL(cleaned))()
aggExprList += alias
alias.toAttribute
case grouping: Expression if agg.groupingExpressions.exists(grouping.semanticEquals) =>
@@ -2777,7 +2901,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
aggExprList += ne
ne.toAttribute
case other =>
- val alias = Alias(other, other.toString)()
+ val alias = Alias(other, toPrettySQL(other))()
aggExprList += alias
alias.toAttribute
}
@@ -2893,22 +3017,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}
- // We must wait until all expressions except for generator functions are resolved before
- // rewriting generator functions in Project/Aggregate. This is necessary to make this rule
- // stable for different execution orders of analyzer rules. See also SPARK-47241.
- private def canRewriteGenerator(namedExprs: Seq[NamedExpression]): Boolean = {
- namedExprs.forall { ne =>
- ne.resolved || {
- trimNonTopLevelAliases(ne) match {
- case AliasedGenerator(_, _, _) => true
- case _ => false
- }
- }
- }
- }
-
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
_.containsPattern(GENERATOR), ruleId) {
+ case p @ Project(Seq(UnresolvedStarWithColumns(_, _, _)), _) =>
+ // UnresolvedStarWithColumns should be resolved before extracting.
+ p
+
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
val nestedGenerator = projectList.find(hasNestedGenerator).get
throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
@@ -2921,8 +3035,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
val generators = aggList.filter(hasGenerator).map(trimAlias)
throw QueryCompilationErrors.moreThanOneGeneratorError(generators)
- case Aggregate(groupList, aggList, child, _) if canRewriteGenerator(aggList) &&
- aggList.exists(hasGenerator) =>
+ case Aggregate(groupList, aggList, child, _) if
+ aggList.forall {
+ case AliasedGenerator(_, _, _) => true
+ case other => other.resolved
+ } && aggList.exists(hasGenerator) =>
// If generator in the aggregate list was visited, set the boolean flag true.
var generatorVisited = false
@@ -2967,8 +3084,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// first for replacing `Project` with `Aggregate`.
p
- case p @ Project(projectList, child) if canRewriteGenerator(projectList) &&
- projectList.exists(hasGenerator) =>
+ // The star will be expanded differently if we insert `Generate` under `Project` too early.
+ case p @ Project(projectList, child) if !projectList.exists(_.exists(_.isInstanceOf[Star])) =>
val (resolvedGenerator, newProjectList) = projectList
.map(trimNonTopLevelAliases)
.foldLeft((None: Option[Generate], Nil: Seq[NamedExpression])) { (res, e) =>
@@ -3534,7 +3651,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
TableOutputResolver.suitableForByNameCheck(v2Write.isByName,
expected = v2Write.table.output, queryOutput = v2Write.query.output)
val projection = TableOutputResolver.resolveOutputColumns(
- v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf)
+ v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf,
+ supportColDefaultValue = true)
if (projection != v2Write.query) {
val cleanedTable = v2Write.table match {
case r: DataSourceV2Relation =>
@@ -4081,3 +4199,18 @@ object RemoveTempResolvedColumn extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Rule that's used to handle `UnresolvedHaving` nodes with resolved `condition` and `child`.
+ * It's placed outside the main batch to avoid conflicts with other rules that resolve
+ * `UnresolvedHaving` in the main batch.
+ */
+object ResolveUnresolvedHaving extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ plan.resolveOperatorsWithPruning(_.containsPattern(UNRESOLVED_HAVING), ruleId) {
+ case u @ UnresolvedHaving(havingCondition, child)
+ if havingCondition.resolved && child.resolved =>
+ Filter(condition = havingCondition, child = child)
+ }
+ }
+}
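As a hedged illustration of the new ResolveUnresolvedHaving rule, one plausible shape that reaches it is a HAVING clause whose condition and child are already resolved but where no Aggregate was introduced; the rule then degenerates the node into a plain Filter. The view and the exact semantics below are illustrative only:

// Assumes a `spark` session in scope; behavior also depends on other HAVING-related rules.
spark.sql(
  "CREATE OR REPLACE TEMP VIEW sketch_v AS SELECT * FROM VALUES (1), (2), (3) AS sketch_v(c)")

// If this UnresolvedHaving survives the main batch with both sides resolved,
// ResolveUnresolvedHaving rewrites it as Filter(c > 1, sketch_v).
spark.sql("SELECT c FROM sketch_v HAVING c > 1").show()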
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index aa977b240007b..a98aaf702acef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -122,7 +122,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
Some(widerType)
}
- case (d1: DatetimeType, d2: DatetimeType) => Some(findWiderDateTimeType(d1, d2))
+ case (d1: DatetimeType, d2: DatetimeType) => findWiderDateTimeType(d1, d2)
case (t1: DayTimeIntervalType, t2: DayTimeIntervalType) =>
Some(DayTimeIntervalType(t1.startField.min(t2.startField), t1.endField.max(t2.endField)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala
index 8df977c809211..cf7ea21ee6f72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala
@@ -68,7 +68,7 @@ object ApplyCharTypePaddingHelper {
private[sql] def paddingForStringComparison(
plan: LogicalPlan,
padCharCol: Boolean): LogicalPlan = {
- plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
+ plan.resolveOperatorsUpWithSubqueriesAndPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
case operator =>
operator.transformExpressionsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
case e if !e.childrenResolved => e
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDDLCommandStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollationToStringType.scala
similarity index 67%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDDLCommandStringTypes.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollationToStringType.scala
index 9ac04236a1b13..cea2988badf4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDDLCommandStringTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollationToStringType.scala
@@ -17,71 +17,79 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.expressions.{Cast, DefaultStringProducingExpression, Expression, Literal}
-import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumns, AlterColumnSpec, AlterTableCommand, AlterViewAs, ColumnDefinition, CreateTable, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan}
+import org.apache.spark.sql.catalyst.expressions.{Cast, DefaultStringProducingExpression, Expression, Literal, SubqueryExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumns, AlterColumnSpec, AlterTableCommand, AlterViewAs, ColumnDefinition, CreateTable, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, ReplaceTable, V2CreateTablePlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.TableCatalog
import org.apache.spark.sql.types.{DataType, StringType}
/**
- * Resolves string types in DDL commands, where the string type inherits the
- * collation from the corresponding object (table/view -> schema -> catalog).
+ * Resolves string types in logical plans by assigning them the appropriate collation. The
+ * collation is inherited from the relevant object in the hierarchy (e.g., table/view -> schema ->
+ * catalog). This rule is primarily applied to DDL commands, but it can also be triggered in other
+ * scenarios. For example, when querying a view, its query is re-resolved each time, and that query
+ * can take various forms.
*/
-object ResolveDDLCommandStringTypes extends Rule[LogicalPlan] {
+object ApplyDefaultCollationToStringType extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
- if (isDDLCommand(plan)) {
- transformDDL(plan)
- } else {
- // For non-DDL commands no need to do any further resolution of string types
- plan
+ fetchDefaultCollation(plan) match {
+ case Some(collation) =>
+ transform(plan, StringType(collation))
+ case None => plan
}
}
- /** Default collation used, if object level collation is not provided */
- private def defaultCollation: String = "UTF8_BINARY"
+ /** Returns the default collation that should be applied to the plan
+ * if specified; otherwise, returns None.
+ */
+ private def fetchDefaultCollation(plan: LogicalPlan): Option[String] = {
+ plan match {
+ case createTable: CreateTable =>
+ createTable.tableSpec.collation
- /** Returns the string type that should be used in a given DDL command */
- private def stringTypeForDDLCommand(table: LogicalPlan): StringType = {
- table match {
- case createTable: CreateTable if createTable.tableSpec.collation.isDefined =>
- StringType(createTable.tableSpec.collation.get)
+ // Unlike for tables, CreateView also handles CREATE OR REPLACE VIEW
+ case createView: CreateView =>
+ createView.collation
- case createView: CreateView if createView.collation.isDefined =>
- StringType(createView.collation.get)
+ case replaceTable: ReplaceTable =>
+ replaceTable.tableSpec.collation
case alterTable: AlterTableCommand if alterTable.table.resolved =>
alterTable.table match {
- case resolvedTbl: ResolvedTable =>
- val collation = resolvedTbl.table.properties.getOrDefault(
- TableCatalog.PROP_COLLATION, defaultCollation)
- StringType(collation)
-
- case _ =>
- // As a safeguard, use the default collation for unknown cases.
- StringType(defaultCollation)
+ case resolvedTbl: ResolvedTable
+ if resolvedTbl.table.properties.containsKey(TableCatalog.PROP_COLLATION) =>
+ Some(resolvedTbl.table.properties.get(TableCatalog.PROP_COLLATION))
+ case _ => None
}
- case _ => StringType(defaultCollation)
- }
- }
+ case alterViewAs: AlterViewAs =>
+ alterViewAs.child match {
+ case resolvedPersistentView: ResolvedPersistentView =>
+ resolvedPersistentView.metadata.collation
+ case resolvedTempView: ResolvedTempView =>
+ resolvedTempView.metadata.collation
+ case _ => None
+ }
+
+ // Check if the view has a default collation
+ case _ if AnalysisContext.get.collation.isDefined =>
+ AnalysisContext.get.collation
- private def isDDLCommand(plan: LogicalPlan): Boolean = plan exists {
- case _: AddColumns | _: ReplaceColumns | _: AlterColumns => true
- case _ => isCreateOrAlterPlan(plan)
+ case _ => None
+ }
}
private def isCreateOrAlterPlan(plan: LogicalPlan): Boolean = plan match {
// For CREATE TABLE, only v2 CREATE TABLE command is supported.
// Also, table DEFAULT COLLATION cannot be specified through CREATE TABLE AS SELECT command.
- case _: V2CreateTablePlan | _: CreateView | _: AlterViewAs => true
+ case _: V2CreateTablePlan | _: ReplaceTable | _: CreateView | _: AlterViewAs => true
case _ => false
}
- private def transformDDL(plan: LogicalPlan): LogicalPlan = {
- val newType = stringTypeForDDLCommand(plan)
-
+ private def transform(plan: LogicalPlan, newType: StringType): LogicalPlan = {
plan resolveOperators {
- case p if isCreateOrAlterPlan(p) =>
+ case p if isCreateOrAlterPlan(p) || AnalysisContext.get.collation.isDefined =>
transformPlan(p, newType)
case addCols: AddColumns =>
@@ -121,11 +129,22 @@ object ResolveDDLCommandStringTypes extends Rule[LogicalPlan] {
case columnDef: ColumnDefinition if hasDefaultStringType(columnDef.dataType) =>
newType => columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType))
- case cast: Cast if hasDefaultStringType(cast.dataType) =>
+ case cast: Cast if hasDefaultStringType(cast.dataType) &&
+ cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined =>
newType => cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType))
case Literal(value, dt) if hasDefaultStringType(dt) =>
newType => Literal(value, replaceDefaultStringType(dt, newType))
+
+ case subquery: SubqueryExpression =>
+ val plan = subquery.plan
+ newType =>
+ val newPlan = plan resolveExpressionsUp { expression =>
+ transformExpression
+ .andThen(_.apply(newType))
+ .applyOrElse(expression, identity[Expression])
+ }
+ subquery.withNewPlan(newPlan)
}
/**
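A hedged sketch of the collation inheritance this renamed rule implements; the DEFAULT COLLATION clause placement and the UTF8_LCASE name are assumptions here, since the rule itself only consumes `tableSpec.collation` and the view/AnalysisContext collation:

// Assumes a `spark` session in scope and a catalog that accepts the clause below.
spark.sql(
  "CREATE TABLE sketch_collated (name STRING, id INT) USING parquet DEFAULT COLLATION UTF8_LCASE")

// Plain STRING columns are rewritten by ApplyDefaultCollationToStringType to carry the
// table-level collation instead of the session default, including inside subquery expressions
// and user-specified casts in later DDL on this table.
spark.sql("DESCRIBE TABLE sketch_collated").show(truncate = false)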
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
index ef13bc191db5c..19e58a6e370b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -56,9 +56,15 @@ object CTESubstitution extends Rule[LogicalPlan] {
return plan
}
- val commands = plan.collect {
- case c @ (_: Command | _: ParsedStatement | _: InsertIntoDir) => c
+ def collectCommands(p: LogicalPlan): Seq[LogicalPlan] = p match {
+ case c @ (_: Command | _: ParsedStatement | _: InsertIntoDir) => Seq(c)
+ case u: UnresolvedWith =>
+ collectCommands(u.child) ++ u.cteRelations.flatMap {
+ case (_, relation) => collectCommands(relation)
+ }
+ case p => p.children.flatMap(collectCommands)
}
+ val commands = collectCommands(plan)
val forceInline = if (commands.length == 1) {
if (conf.getConf(SQLConf.LEGACY_INLINE_CTE_IN_COMMANDS)) {
// The legacy behavior always inlines the CTE relations for queries in commands.
@@ -78,14 +84,14 @@ object CTESubstitution extends Rule[LogicalPlan] {
val cteDefs = ArrayBuffer.empty[CTERelationDef]
val (substituted, firstSubstituted) =
- LegacyBehaviorPolicy.withName(conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY)) match {
+ conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY) match {
case LegacyBehaviorPolicy.EXCEPTION =>
assertNoNameConflictsInCTE(plan)
- traverseAndSubstituteCTE(plan, forceInline, Seq.empty, cteDefs)
+ traverseAndSubstituteCTE(plan, forceInline, Seq.empty, cteDefs, None)
case LegacyBehaviorPolicy.LEGACY =>
(legacyTraverseAndSubstituteCTE(plan, cteDefs), None)
case LegacyBehaviorPolicy.CORRECTED =>
- traverseAndSubstituteCTE(plan, forceInline, Seq.empty, cteDefs)
+ traverseAndSubstituteCTE(plan, forceInline, Seq.empty, cteDefs, None)
}
if (cteDefs.isEmpty) {
substituted
@@ -156,7 +162,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
messageParameters = Map.empty)
}
val resolvedCTERelations = resolveCTERelations(relations, isLegacy = true,
- forceInline = false, Seq.empty, cteDefs, allowRecursion)
+ forceInline = false, Seq.empty, cteDefs, None, allowRecursion)
substituteCTE(child, alwaysInline = true, resolvedCTERelations, None)
}
}
@@ -196,6 +202,8 @@ object CTESubstitution extends Rule[LogicalPlan] {
* @param forceInline always inline the CTE relations if this is true
* @param outerCTEDefs already resolved outer CTE definitions with names
* @param cteDefs all accumulated CTE definitions
+ * @param recursiveCTERelationAncestor indicates whether we are inside a recursive CTE and,
+ * if so, which CTE definition that is.
* @return the plan where CTE substitution is applied and optionally the last substituted `With`
* where CTE definitions will be gathered to
*/
@@ -203,7 +211,9 @@ object CTESubstitution extends Rule[LogicalPlan] {
plan: LogicalPlan,
forceInline: Boolean,
outerCTEDefs: Seq[(String, CTERelationDef)],
- cteDefs: ArrayBuffer[CTERelationDef]): (LogicalPlan, Option[LogicalPlan]) = {
+ cteDefs: ArrayBuffer[CTERelationDef],
+ recursiveCTERelationAncestor: Option[(String, CTERelationDef)]
+ ): (LogicalPlan, Option[LogicalPlan]) = {
var firstSubstituted: Option[LogicalPlan] = None
val newPlan = plan.resolveOperatorsDownWithPruning(
_.containsAnyPattern(UNRESOLVED_WITH, PLAN_EXPRESSION)) {
@@ -214,18 +224,31 @@ object CTESubstitution extends Rule[LogicalPlan] {
errorClass = "RECURSIVE_CTE_WHEN_INLINING_IS_FORCED",
messageParameters = Map.empty)
}
- val resolvedCTERelations =
+
+ val tempCteDefs = ArrayBuffer.empty[CTERelationDef]
+ val resolvedCTERelations = if (recursiveCTERelationAncestor.isDefined) {
+ resolveCTERelations(relations, isLegacy = false, forceInline = false, outerCTEDefs,
+ tempCteDefs, recursiveCTERelationAncestor, allowRecursion) ++ outerCTEDefs
+ } else {
resolveCTERelations(relations, isLegacy = false, forceInline, outerCTEDefs, cteDefs,
- allowRecursion) ++ outerCTEDefs
+ recursiveCTERelationAncestor, allowRecursion) ++ outerCTEDefs
+ }
val substituted = substituteCTE(
- traverseAndSubstituteCTE(child, forceInline, resolvedCTERelations, cteDefs)._1,
+ traverseAndSubstituteCTE(child, forceInline, resolvedCTERelations, cteDefs,
+ recursiveCTERelationAncestor)._1,
+ // If we are resolving CTEs inside a recursive CTE, we need to inline them in case they
+ // contain a self reference.
forceInline,
resolvedCTERelations,
None)
if (firstSubstituted.isEmpty) {
firstSubstituted = Some(substituted)
}
- substituted
+ if (recursiveCTERelationAncestor.isDefined) {
+ withCTEDefs(substituted, tempCteDefs.toSeq)
+ } else {
+ substituted
+ }
case other =>
other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
@@ -241,6 +264,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
forceInline: Boolean,
outerCTEDefs: Seq[(String, CTERelationDef)],
cteDefs: ArrayBuffer[CTERelationDef],
+ recursiveCTERelationAncestor: Option[(String, CTERelationDef)],
allowRecursion: Boolean): Seq[(String, CTERelationDef)] = {
val alwaysInline = isLegacy || forceInline
var resolvedCTERelations = if (alwaysInline) {
@@ -249,6 +273,21 @@ object CTESubstitution extends Rule[LogicalPlan] {
outerCTEDefs
}
for ((name, relation) <- relations) {
+ // If recursion is allowed (RECURSIVE keyword specified)
+ // then it has higher priority than outer or previous relations.
+ // Therefore, we construct a `CTERelationDef` for the current relation.
+ // Later, when we encounter an unresolved relation and need to determine which CTE Def it
+ // references, we first check whether it refers to this one. If it does, we mark the
+ // reference as recursive.
+ val recursiveCTERelation = if (allowRecursion) {
+ Some(name -> CTERelationDef(relation))
+ } else {
+ // If this relation isn't recursive but an outer recursive CTE encloses it, the self
+ // reference that gets first-check priority is the CTERelationDef of that recursive
+ // ancestor.
+ recursiveCTERelationAncestor
+ }
+
val innerCTEResolved = if (isLegacy) {
// In legacy mode, outer CTE relations take precedence. Here we don't resolve the inner
// `With` nodes, later we will substitute `UnresolvedRelation`s with outer CTE relations.
@@ -299,26 +338,20 @@ object CTESubstitution extends Rule[LogicalPlan] {
} else {
resolvedCTERelations
}
- traverseAndSubstituteCTE(relation, forceInline, nonConflictingCTERelations, cteDefs)._1
+ traverseAndSubstituteCTE(relation, forceInline, nonConflictingCTERelations,
+ cteDefs, recursiveCTERelation)._1
}
- // If recursion is allowed (RECURSIVE keyword specified)
- // then it has higher priority than outer or previous relations.
- // Therefore, we construct a `CTERelationDef` for the current relation.
- // Later if we encounter unresolved relation which we need to find which CTE Def it is
- // referencing to, we first check if it is a reference to this one. If yes, then we set the
- // reference as being recursive.
- val recursiveCTERelation = if (allowRecursion) {
- Some(name -> CTERelationDef(relation))
- } else {
- None
- }
// CTE definition can reference a previous one or itself if recursion allowed.
val substituted = substituteCTE(innerCTEResolved, alwaysInline,
resolvedCTERelations, recursiveCTERelation)
- val cteRelation = recursiveCTERelation
+ val cteRelation = if (allowRecursion) {
+ recursiveCTERelation
.map(_._2.copy(child = substituted))
.getOrElse(CTERelationDef(substituted))
+ } else {
+ CTERelationDef(substituted)
+ }
if (!alwaysInline) {
cteDefs += cteRelation
}
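The new recursiveCTERelationAncestor parameter threads the enclosing recursive definition into nested WITH clauses, so that an inner CTE defined inside the recursive term can still resolve (and is force-inlined around) the self-reference. A hedged sketch of such a query; whether this exact form is accepted also depends on the rest of the recursive CTE support:

// Assumes a `spark` session in scope.
spark.sql("""
  WITH RECURSIVE r(n) AS (
    SELECT 1
    UNION ALL
    (WITH step AS (SELECT n + 1 AS n FROM r WHERE n < 5) SELECT n FROM step)
  )
  SELECT * FROM r
""").show()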
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 696c1e10d060f..65a0647ce92b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -21,12 +21,11 @@ import scala.collection.mutable
import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
-import org.apache.spark.sql.catalyst.analysis.ResolveWithCTE.{checkForSelfReferenceInSubquery, checkIfSelfReferenceIsPlacedCorrectly}
+import org.apache.spark.sql.catalyst.analysis.ResolveWithCTE.checkIfSelfReferenceIsPlacedCorrectly
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ListAgg, Median, PercentileCont, PercentileDisc}
import org.apache.spark.sql.catalyst.optimizer.InlineCTE
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, PLAN_EXPRESSION, UNRESOLVED_WINDOW_EXPRESSION}
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils}
import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement}
@@ -51,9 +50,6 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
*/
val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
- val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Unit]("dataTypeMismatchError")
- val INVALID_FORMAT_ERROR = TreeNodeTag[Unit]("invalidFormatError")
-
// Error that is not supposed to throw immediately on triggering, e.g. certain internal errors.
// The error will be thrown at the end of the whole check analysis process, if no other error
// occurs.
@@ -279,9 +275,6 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
case _ =>
}
- // Check if there is any self-reference within subqueries
- checkForSelfReferenceInSubquery(plan)
-
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
plan.foreachUp {
@@ -348,6 +341,16 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
case operator: LogicalPlan =>
operator transformExpressionsDown {
+ case hof: HigherOrderFunction if hof.arguments.exists {
+ case LambdaFunction(_, _, _) => true
+ case _ => false
+ } =>
+ throw new AnalysisException(
+ errorClass =
+ "INVALID_LAMBDA_FUNCTION_CALL.PARAMETER_DOES_NOT_ACCEPT_LAMBDA_FUNCTION",
+ messageParameters = Map.empty,
+ origin = hof.origin
+ )
// Check argument data types of higher-order functions downwards first.
// If the arguments of the higher-order functions are resolved but the type check fails,
// the argument functions will not get resolved, but we should report the argument type
@@ -358,7 +361,6 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
case checkRes: TypeCheckResult.DataTypeMismatch =>
hof.dataTypeMismatch(hof, checkRes)
case checkRes: TypeCheckResult.InvalidFormat =>
- hof.setTagValue(INVALID_FORMAT_ERROR, ())
hof.invalidFormat(checkRes)
}
@@ -407,23 +409,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowName)
case e: Expression if e.checkInputDataTypes().isFailure =>
- e.checkInputDataTypes() match {
- case checkRes: TypeCheckResult.DataTypeMismatch =>
- e.setTagValue(DATA_TYPE_MISMATCH_ERROR, ())
- e.dataTypeMismatch(e, checkRes)
- case TypeCheckResult.TypeCheckFailure(message) =>
- e.setTagValue(DATA_TYPE_MISMATCH_ERROR, ())
- val extraHint = extraHintForAnsiTypeCoercionExpression(operator)
- e.failAnalysis(
- errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
- messageParameters = Map(
- "sqlExpr" -> toSQLExpr(e),
- "msg" -> message,
- "hint" -> extraHint))
- case checkRes: TypeCheckResult.InvalidFormat =>
- e.setTagValue(INVALID_FORMAT_ERROR, ())
- e.invalidFormat(checkRes)
- }
+ TypeCoercionValidation.failOnTypeCheckResult(e, Some(operator))
case c: Cast if !c.resolved =>
throw SparkException.internalError(
@@ -717,7 +703,8 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
)
}
- val dataTypesAreCompatibleFn = getDataTypesAreCompatibleFn(operator)
+ val dataTypesAreCompatibleFn =
+ TypeCoercionValidation.getDataTypesAreCompatibleFn(operator)
// Check if the data types match.
dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) =>
// SPARK-18058: we shall not care about the nullability of columns
@@ -728,7 +715,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
tableOrdinalNumber = ti + 1,
dataType1 = dt1,
dataType2 = dt2,
- hint = extraHintForAnsiTypeCoercionPlan(operator),
+ hint = TypeCoercionValidation.getHintForOperatorCoercion(operator),
origin = operator.origin
)
}
@@ -869,20 +856,18 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
// used in equality comparison, remove this type check once we support it.
case o if mapColumnInSetOperation(o).isDefined =>
val mapCol = mapColumnInSetOperation(o).get
- o.failAnalysis(
- errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE",
- messageParameters = Map(
- "colName" -> toSQLId(mapCol.name),
- "dataType" -> toSQLType(mapCol.dataType)))
+ throw QueryCompilationErrors.unsupportedSetOperationOnMapType(
+ mapCol = mapCol,
+ origin = operator.origin
+ )
// TODO: Remove this type check once we support Variant ordering
case o if variantColumnInSetOperation(o).isDefined =>
val variantCol = variantColumnInSetOperation(o).get
- o.failAnalysis(
- errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_VARIANT_TYPE",
- messageParameters = Map(
- "colName" -> toSQLId(variantCol.name),
- "dataType" -> toSQLType(variantCol.dataType)))
+ throw QueryCompilationErrors.unsupportedSetOperationOnVariantType(
+ variantCol = variantCol,
+ origin = operator.origin
+ )
case o if variantExprInPartitionExpression(o).isDefined =>
val variantExpr = variantExprInPartitionExpression(o).get
@@ -962,85 +947,30 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
}
}
- private def getDataTypesAreCompatibleFn(plan: LogicalPlan): (DataType, DataType) => Boolean = {
- val isUnion = plan.isInstanceOf[Union]
- if (isUnion) {
- (dt1: DataType, dt2: DataType) =>
- DataType.equalsStructurally(dt1, dt2, true)
- } else {
- // SPARK-18058: we shall not care about the nullability of columns
- (dt1: DataType, dt2: DataType) =>
- TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).nonEmpty
- }
- }
-
- private def getDefaultTypeCoercionPlan(plan: LogicalPlan): LogicalPlan =
- TypeCoercion.typeCoercionRules.foldLeft(plan) { case (p, rule) => rule(p) }
-
- private def extraHintMessage(issueFixedIfAnsiOff: Boolean): String = {
- if (issueFixedIfAnsiOff) {
- "\nTo fix the error, you might need to add explicit type casts. If necessary set " +
- s"${SQLConf.ANSI_ENABLED.key} to false to bypass this error."
- } else {
- ""
- }
- }
-
- private[analysis] def extraHintForAnsiTypeCoercionExpression(plan: LogicalPlan): String = {
- if (!SQLConf.get.ansiEnabled) {
- ""
- } else {
- val nonAnsiPlan = getDefaultTypeCoercionPlan(plan)
- var issueFixedIfAnsiOff = true
- getAllExpressions(nonAnsiPlan).foreach(_.foreachUp {
- case e: Expression if e.getTagValue(DATA_TYPE_MISMATCH_ERROR).isDefined &&
- e.checkInputDataTypes().isFailure =>
- e.checkInputDataTypes() match {
- case TypeCheckResult.TypeCheckFailure(_) | _: TypeCheckResult.DataTypeMismatch =>
- issueFixedIfAnsiOff = false
- }
-
- case _ =>
- })
- extraHintMessage(issueFixedIfAnsiOff)
- }
- }
-
- private def extraHintForAnsiTypeCoercionPlan(plan: LogicalPlan): String = {
- if (!SQLConf.get.ansiEnabled) {
- ""
- } else {
- val nonAnsiPlan = getDefaultTypeCoercionPlan(plan)
- var issueFixedIfAnsiOff = true
- nonAnsiPlan match {
- case _: Union | _: SetOperation if nonAnsiPlan.children.length > 1 =>
- def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType)
-
- val ref = dataTypes(nonAnsiPlan.children.head)
- val dataTypesAreCompatibleFn = getDataTypesAreCompatibleFn(nonAnsiPlan)
- nonAnsiPlan.children.tail.zipWithIndex.foreach { case (child, ti) =>
- // Check if the data types match.
- dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) =>
- if (!dataTypesAreCompatibleFn(dt1, dt2)) {
- issueFixedIfAnsiOff = false
- }
- }
- }
-
- case _ =>
- }
- extraHintMessage(issueFixedIfAnsiOff)
- }
- }
-
def checkSubqueryExpression(plan: LogicalPlan, expr: SubqueryExpression): Unit = {
if (expr.plan.isStreaming) {
plan.failAnalysis("INVALID_SUBQUERY_EXPRESSION.STREAMING_QUERY", Map.empty)
}
+ assertNoRecursiveCTE(expr.plan)
checkAnalysis0(expr.plan)
ValidateSubqueryExpression(plan, expr)
}
+ private def assertNoRecursiveCTE(plan: LogicalPlan): Unit = {
+ plan.foreach {
+ case r: CTERelationRef if r.recursive =>
+ throw new AnalysisException(
+ errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
+ messageParameters = Map.empty)
+ case p => p.expressions.filter(_.containsPattern(PLAN_EXPRESSION)).foreach {
+ expr => expr.foreach {
+ case s: SubqueryExpression => assertNoRecursiveCTE(s.plan)
+ case _ =>
+ }
+ }
+ }
+ }
+
/**
* Validate that collected metrics names are unique. The same name cannot be used for metrics
* with different results. However, multiple instances of metrics with the same result and name
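The self-reference check that used to live in ResolveWithCTE (checkForSelfReferenceInSubquery) is now enforced per subquery expression via assertNoRecursiveCTE in checkSubqueryExpression. A hedged sketch of a query it is expected to reject:

// Assumes a `spark` session in scope; expected to fail with INVALID_RECURSIVE_REFERENCE.PLACE
// because the recursive CTE references itself inside a scalar subquery expression.
spark.sql("""
  WITH RECURSIVE r(n) AS (
    SELECT 1
    UNION ALL
    SELECT n + 1 FROM r WHERE n < (SELECT max(n) FROM r)
  )
  SELECT * FROM r
""").show()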
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index e778342d08374..b2e068fd990ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -406,7 +406,14 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
// Lateral column alias does not have qualifiers. We always use the first name part to
// look up lateral column aliases.
val lowerCasedName = u.nameParts.head.toLowerCase(Locale.ROOT)
- aliasMap.get(lowerCasedName).map {
+ aliasMap.get(lowerCasedName).filter {
+ // Do not resolve LCA with aliased `Generator`, as it will be rewritten by the rule
+ // `ExtractGenerator` with fresh output attribute IDs. The `Generator` will be pulled
+ // out and put in a `Generate` node below `Project`, so that we can resolve the column
+ // normally without LCA resolution.
+ case scala.util.Left(alias) => !alias.child.isInstanceOf[Generator]
+ case _ => true
+ }.map {
case scala.util.Left(alias) =>
if (alias.resolved) {
val resolvedAttr = resolveExpressionByPlanOutput(
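A hedged sketch of the lateral-column-alias case the new filter above targets: the alias wraps a Generator, so LCA resolution now skips it and lets ExtractGenerator pull the generator into a Generate node first, after which the reference resolves against the Generate output:

// Assumes a `spark` session in scope; column names are illustrative.
spark.sql("SELECT explode(array(1, 2)) AS e, e + 1 AS e_plus_one").show()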
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 8398fb8d1e830..752a2a648ce99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -22,9 +22,12 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, OuterReference, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern._
object DeduplicateRelations extends Rule[LogicalPlan] {
+ val PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION =
+ TreeNodeTag[Unit]("project_for_expression_id_deduplication")
type ExprIdMap = mutable.HashMap[Class[_], mutable.HashSet[Long]]
@@ -67,7 +70,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
val projectList = child.output.map { attr =>
Alias(attr, attr.name)()
}
- Project(projectList, child)
+ val project = Project(projectList, child)
+ project.setTagValue(DeduplicateRelations.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION, ())
+ project
}
}
u.copy(children = newChildren)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 597ac57a10ccc..66db1fe8b5965 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -463,6 +463,7 @@ object FunctionRegistry {
expressionBuilder("try_sum", TrySumExpressionBuilder, setAlias = true),
expression[TryToBinary]("try_to_binary"),
expressionBuilder("try_to_timestamp", TryToTimestampExpressionBuilder, setAlias = true),
+ expressionBuilder("try_to_time", TryToTimeExpressionBuilder, setAlias = true),
expression[TryAesDecrypt]("try_aes_decrypt"),
expression[TryReflect]("try_reflect"),
expression[TryUrlDecode]("try_url_decode"),
@@ -626,6 +627,7 @@ object FunctionRegistry {
expression[CurrentDate]("current_date"),
expressionBuilder("curdate", CurDateExpressionBuilder, setAlias = true),
expression[CurrentTimestamp]("current_timestamp"),
+ expression[CurrentTime]("current_time"),
expression[CurrentTimeZone]("current_timezone"),
expression[LocalTimestamp]("localtimestamp"),
expression[DateDiff]("datediff"),
@@ -639,17 +641,18 @@ object FunctionRegistry {
expression[DayOfMonth]("dayofmonth"),
expression[FromUnixTime]("from_unixtime"),
expression[FromUTCTimestamp]("from_utc_timestamp"),
- expression[Hour]("hour"),
+ expressionBuilder("hour", HourExpressionBuilder),
expression[LastDay]("last_day"),
- expression[Minute]("minute"),
+ expressionBuilder("minute", MinuteExpressionBuilder),
expression[Month]("month"),
expression[MonthsBetween]("months_between"),
expression[NextDay]("next_day"),
expression[Now]("now"),
expression[Quarter]("quarter"),
- expression[Second]("second"),
+ expressionBuilder("second", SecondExpressionBuilder),
expression[ParseToTimestamp]("to_timestamp"),
expression[ParseToDate]("to_date"),
+ expression[ToTime]("to_time"),
expression[ToBinary]("to_binary"),
expression[ToUnixTimestamp]("to_unix_timestamp"),
expression[ToUTCTimestamp]("to_utc_timestamp"),
@@ -668,6 +671,7 @@ object FunctionRegistry {
expression[SessionWindow]("session_window"),
expression[WindowTime]("window_time"),
expression[MakeDate]("make_date"),
+ expression[MakeTime]("make_time"),
expression[MakeTimestamp]("make_timestamp"),
expression[TryMakeTimestamp]("try_make_timestamp"),
expression[MonthName]("monthname"),
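The registrations above add a TIME-oriented function family. A hedged sketch of how they might be invoked; the argument shapes are assumptions based on the expression names, not on documentation in this diff:

// Assumes a `spark` session in scope.
spark.sql("SELECT current_time()").show()
spark.sql("SELECT make_time(12, 30, 45.5)").show()
spark.sql("SELECT to_time('12:30:45'), try_to_time('not a time')").show()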
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/KeepLegacyOutputs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/KeepLegacyOutputs.scala
index a40b96732bae2..3539e910169bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/KeepLegacyOutputs.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/KeepLegacyOutputs.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.plans.logical.{DescribeNamespace, LogicalPlan, ShowNamespaces, ShowTableProperties, ShowTables}
+import org.apache.spark.sql.catalyst.plans.logical.{DescribeNamespace, LogicalPlan, ShowTableProperties, ShowTables}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.internal.SQLConf
@@ -36,9 +36,6 @@ object KeepLegacyOutputs extends Rule[LogicalPlan] {
assert(s.output.length == 3)
val newOutput = s.output.head.withName("database") +: s.output.tail
s.copy(output = newOutput)
- case s: ShowNamespaces =>
- assert(s.output.length == 1)
- s.copy(output = Seq(s.output.head.withName("databaseName")))
case d: DescribeNamespace =>
assert(d.output.length == 2)
d.copy(output = Seq(d.output.head.withName("database_description_item"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/LiteralFunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/LiteralFunctionResolution.scala
index c7faf0536b77d..865a780b61dae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/LiteralFunctionResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/LiteralFunctionResolution.scala
@@ -17,16 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.expressions.{
- Alias,
- CurrentDate,
- CurrentTimestamp,
- CurrentUser,
- Expression,
- GroupingID,
- NamedExpression,
- VirtualColumn
-}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTime, CurrentTimestamp, CurrentUser, Expression, GroupingID, NamedExpression, VirtualColumn}
import org.apache.spark.sql.catalyst.util.toPrettySQL
/**
@@ -47,10 +38,12 @@ object LiteralFunctionResolution {
}
}
- // support CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_USER, USER, SESSION_USER and grouping__id
+ // support CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_TIME,
+ // CURRENT_USER, USER, SESSION_USER and grouping__id
private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq(
(CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)),
(CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)),
+ (CurrentTime().prettyName, () => CurrentTime(), toPrettySQL(_)),
(CurrentUser().prettyName, () => CurrentUser(), toPrettySQL),
("user", () => CurrentUser(), toPrettySQL),
("session_user", () => CurrentUser(), toPrettySQL),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NormalizeableRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NormalizeableRelation.scala
new file mode 100644
index 0000000000000..08bca79094a4d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NormalizeableRelation.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+/**
+ * [[NormalizeableRelation]] is extended by relations that contain non-deterministic or
+ * time-dependent objects that need to be normalized by [[NormalizePlan]]. This way, logical
+ * plans produced by different calls to the Analyzer can be compared.
+ *
+ * This is used in unit tests and to check whether the plans from the fixed-point and the
+ * single-pass Analyzers are the same.
+ */
+trait NormalizeableRelation {
+ def normalize(): LogicalPlan
+}
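A minimal sketch of how a relation could opt into this normalization; HypotheticalClockRelation is invented for illustration and only the trait itself comes from this change:

import org.apache.spark.sql.catalyst.analysis.NormalizeableRelation
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}

// Hypothetical leaf relation that captures a wall-clock snapshot when it is created.
case class HypotheticalClockRelation(output: Seq[Attribute], snapshotMillis: Long)
  extends LeafNode with NormalizeableRelation {
  // Drop the time-dependent field so plans from different analyzer runs compare equal.
  override def normalize(): LogicalPlan = copy(snapshotMillis = 0L)
}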
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
index 2cf3c6390d5fb..de0262edccf69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
@@ -17,27 +17,91 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression, SubqueryExpression, VariableReference}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.plans.logical.{CreateView, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StringType
/**
* Resolves the identifier expressions and builds the original plans/expressions.
*/
-object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper {
+class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch])
+ extends Rule[LogicalPlan] with AliasHelper with EvalHelper {
+
+ private val executor = new RuleExecutor[LogicalPlan] {
+ override def batches: Seq[Batch] = earlyBatches.asInstanceOf[Seq[Batch]]
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ plan match {
+ case createView: CreateView =>
+ if (conf.getConf(SQLConf.VARIABLES_UNDER_IDENTIFIER_IN_VIEW)) {
+ apply0(createView)
+ } else {
+ val referredTempVars = new mutable.ArrayBuffer[Seq[String]]
+ val analyzedChild = apply0(createView.child)
+ val analyzedQuery = apply0(createView.query, Some(referredTempVars))
+ if (referredTempVars.nonEmpty) {
+ throw QueryCompilationErrors.notAllowedToCreatePermanentViewByReferencingTempVarError(
+ Seq("unknown"),
+ referredTempVars.head
+ )
+ }
+ createView.copy(child = analyzedChild, query = analyzedQuery)
+ }
+ case _ => apply0(plan)
+ }
+ }
+
+ private def apply0(
+ plan: LogicalPlan,
+ referredTempVars: Option[mutable.ArrayBuffer[Seq[String]]] = None): LogicalPlan =
+ plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_IDENTIFIER)) {
+ case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved =>
+
+ if (referredTempVars.isDefined) {
+ referredTempVars.get ++= collectTemporaryVariablesInLogicalPlan(p)
+ }
- override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
- _.containsPattern(UNRESOLVED_IDENTIFIER)) {
- case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved =>
- p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children)
- case other =>
- other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
- case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
- e.exprBuilder.apply(evalIdentifierExpr(e.identifierExpr), e.otherExprs)
- }
+ executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children))
+ case other =>
+ other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
+ case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
+
+ if (referredTempVars.isDefined) {
+ referredTempVars.get ++= collectTemporaryVariablesInExpressionTree(e)
+ }
+
+ e.exprBuilder.apply(evalIdentifierExpr(e.identifierExpr), e.otherExprs)
+ }
+ }
+
+ private def collectTemporaryVariablesInLogicalPlan(child: LogicalPlan): Seq[Seq[String]] = {
+ def collectTempVars(child: LogicalPlan): Seq[Seq[String]] = {
+ child.flatMap { plan =>
+ plan.expressions.flatMap { e => collectTemporaryVariablesInExpressionTree(e) }
+ }.distinct
+ }
+ collectTempVars(child)
+ }
+
+ private def collectTemporaryVariablesInExpressionTree(child: Expression): Seq[Seq[String]] = {
+ def collectTempVars(child: Expression): Seq[Seq[String]] = {
+ child.flatMap { expr =>
+ expr.children.flatMap(_.flatMap {
+ case e: SubqueryExpression => collectTemporaryVariablesInLogicalPlan(e.plan)
+ case r: VariableReference => Seq(r.originalNameParts)
+ case _ => Seq.empty
+ })
+ }.distinct
+ }
+ collectTempVars(child)
}
private def evalIdentifierExpr(expr: Expression): Seq[String] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala
index 7ea90854932e5..f01c00f9fa756 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, IntegerLiteral, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE}
@@ -129,27 +129,9 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S
groupExprs: Seq[Expression]): Seq[Expression] = {
assert(selectList.forall(_.resolved))
if (isGroupByAll(groupExprs)) {
- val expandedGroupExprs = expandGroupByAll(selectList)
- if (expandedGroupExprs.isEmpty) {
- // Don't replace the ALL when we fail to infer the grouping columns. We will eventually
- // tell the user in checkAnalysis that we cannot resolve the all in group by.
- groupExprs
- } else {
- // This is a valid GROUP BY ALL aggregate.
- expandedGroupExprs.get.zipWithIndex.map { case (expr, index) =>
- trimAliases(expr) match {
- // HACK ALERT: If the expanded grouping expression is an integer literal, don't use it
- // but use an integer literal of the index. The reason is we may repeatedly
- // analyze the plan, and the original integer literal may cause failures
- // with a later GROUP BY ordinal resolution. GROUP BY constant is
- // meaningless so whatever value does not matter here.
- case IntegerLiteral(_) =>
- // GROUP BY ordinal uses 1-based index.
- Literal(index + 1)
- case _ => expr
- }
- }
- }
+ // Don't replace the ALL when we fail to infer the grouping columns. We will eventually tell
+ // the user in checkAnalysis that we cannot resolve the all in group by.
+ expandGroupByAll(selectList).getOrElse(groupExprs)
} else {
groupExprs
}
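With the integer-literal workaround removed above, GROUP BY ALL now simply expands to the non-aggregate select-list expressions (or is left untouched when expansion fails, so checkAnalysis reports it). A small sketch, with an illustrative temp view:

// Assumes a `spark` session in scope.
spark.sql(
  "CREATE OR REPLACE TEMP VIEW sketch_sales AS " +
    "SELECT * FROM VALUES ('a', 1), ('a', 2), ('b', 5) AS sketch_sales(dept, amount)")

// GROUP BY ALL expands to GROUP BY dept, the only non-aggregate item in the select list.
spark.sql("SELECT dept, sum(amount) FROM sketch_sales GROUP BY ALL").show()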
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
index 7d2c5b7d2e2d4..6ec465f7ffe76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
@@ -83,6 +83,26 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
cteDef.copy(child = alias.copy(child = loop))
}
+ // Simple case of duplicating (UNION ALL) clause.
+ case alias @ SubqueryAlias(_, withCTE @ WithCTE(
+ Union(Seq(anchor, recursion), false, false), innerCteDefs)) =>
+ if (!anchor.resolved) {
+ cteDef
+ } else {
+ // We need to check whether any of the inner CTEs has a self reference and replace
+ // it if needed
+ val newInnerCteDefs = innerCteDefs.map { innerCteDef =>
+ innerCteDef.copy(child = rewriteRecursiveCTERefs(
+ innerCteDef.child, anchor, cteDef.id, None))
+ }
+ val loop = UnionLoop(
+ cteDef.id,
+ anchor,
+ rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, None))
+ cteDef.copy(child = alias.copy(child = withCTE.copy(
+ plan = loop, cteDefs = newInnerCteDefs)))
+ }
+
// The case of CTE name followed by a parenthesized list of column name(s), eg.
// WITH RECURSIVE t(n).
case alias @ SubqueryAlias(_,
@@ -100,11 +120,38 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
cteDef.copy(child = alias.copy(child = columnAlias.copy(child = loop)))
}
- // If the recursion is described with an UNION (deduplicating) clause then the
+ // The case of CTE name followed by a parenthesized list of column name(s), eg.
+ // WITH RECURSIVE t(n).
+ case alias @ SubqueryAlias(_,
+ columnAlias @ UnresolvedSubqueryColumnAliases(
+ colNames,
+ withCTE @ WithCTE(Union(Seq(anchor, recursion), false, false), innerCteDefs)
+ )) =>
+ if (!anchor.resolved) {
+ cteDef
+ } else {
+ // We need to check whether any of the inner CTEs has a self reference and replace
+ // it if needed
+ val newInnerCteDefs = innerCteDefs.map { innerCteDef =>
+ innerCteDef.copy(child = rewriteRecursiveCTERefs(
+ innerCteDef.child, anchor, cteDef.id, Some(colNames)))
+ }
+ val loop = UnionLoop(
+ cteDef.id,
+ anchor,
+ rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)))
+ cteDef.copy(child = alias.copy(child = columnAlias.copy(
+ child = withCTE.copy(plan = loop, cteDefs = newInnerCteDefs))))
+ }
+
+ // If the recursion is described with a UNION (deduplicating) clause then the
// recursive term should not return those rows that have been calculated previously,
// and we exclude those rows from the current iteration result.
case alias @ SubqueryAlias(_,
Distinct(Union(Seq(anchor, recursion), false, false))) =>
+ cteDef.failAnalysis(
+ errorClass = "UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE",
+ messageParameters = Map.empty)
if (!anchor.resolved) {
cteDef
} else {
@@ -120,12 +167,43 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
cteDef.copy(child = alias.copy(child = loop))
}
+ // UNION case with CTEs inside.
+ case alias @ SubqueryAlias(_, withCTE @ WithCTE(
+ Distinct(Union(Seq(anchor, recursion), false, false)), innerCteDefs)) =>
+ cteDef.failAnalysis(
+ errorClass = "UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE",
+ messageParameters = Map.empty)
+ if (!anchor.resolved) {
+ cteDef
+ } else {
+ // We need to check whether any of the inner CTEs has a self reference and replace
+ // it if needed
+ val newInnerCteDefs = innerCteDefs.map { innerCteDef =>
+ innerCteDef.copy(child = rewriteRecursiveCTERefs(
+ innerCteDef.child, anchor, cteDef.id, None))
+ }
+ val loop = UnionLoop(
+ cteDef.id,
+ Distinct(anchor),
+ Except(
+ rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, None),
+ UnionLoopRef(cteDef.id, anchor.output, true),
+ isAll = false
+ )
+ )
+ cteDef.copy(child = alias.copy(child = withCTE.copy(
+ plan = loop, cteDefs = newInnerCteDefs)))
+ }
+
// The case of CTE name followed by a parenthesized list of column name(s).
case alias @ SubqueryAlias(_,
columnAlias@UnresolvedSubqueryColumnAliases(
colNames,
Distinct(Union(Seq(anchor, recursion), false, false))
)) =>
+ cteDef.failAnalysis(
+ errorClass = "UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE",
+ messageParameters = Map.empty)
if (!anchor.resolved) {
cteDef
} else {
@@ -141,6 +219,37 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
cteDef.copy(child = alias.copy(child = columnAlias.copy(child = loop)))
}
+ // The case of CTE name followed by a parenthesized list of column name(s).
+ case alias @ SubqueryAlias(_,
+ columnAlias@UnresolvedSubqueryColumnAliases(
+ colNames,
+ WithCTE(Distinct(Union(Seq(anchor, recursion), false, false)), innerCteDefs)
+ )) =>
+ cteDef.failAnalysis(
+ errorClass = "UNION_NOT_SUPPORTED_IN_RECURSIVE_CTE",
+ messageParameters = Map.empty)
+ if (!anchor.resolved) {
+ cteDef
+ } else {
+ // We need to check whether any of the inner CTEs has a self reference and replace
+ // it if needed
+ val newInnerCteDefs = innerCteDefs.map { innerCteDef =>
+ innerCteDef.copy(child = rewriteRecursiveCTERefs(
+ innerCteDef.child, anchor, cteDef.id, Some(colNames)))
+ }
+ val loop = UnionLoop(
+ cteDef.id,
+ Distinct(anchor),
+ Except(
+ rewriteRecursiveCTERefs(recursion, anchor, cteDef.id, Some(colNames)),
+ UnionLoopRef(cteDef.id, anchor.output, true),
+ isAll = false
+ )
+ )
+ cteDef.copy(child = alias.copy(child = columnAlias.copy(
+ child = withCTE.copy(plan = loop, cteDefs = newInnerCteDefs))))
+ }
+
case other =>
// We do not support cases of sole Union (needs a SubqueryAlias above it), nor
// Project (as UnresolvedSubqueryColumnAliases have not been substituted with the
@@ -187,22 +296,6 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
}
}
- /**
- * Checks if there is any self-reference within subqueries and throws an error
- * if that is the case.
- */
- def checkForSelfReferenceInSubquery(plan: LogicalPlan): Unit = {
- plan.subqueriesAll.foreach { subquery =>
- subquery.foreach {
- case r: CTERelationRef if r.recursive =>
- throw new AnalysisException(
- errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
- messageParameters = Map.empty)
- case _ =>
- }
- }
- }
-
/**
* Counts number of self-references in a recursive CTE definition and throws an error
* if that number is bigger than 1.
@@ -230,6 +323,9 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
case Join(left, right, Inner, _, _) =>
checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
+ case Join(left, right, Cross, _, _) =>
+ checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+ checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
case Join(left, right, LeftOuter, _, _) =>
checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false)
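
For context, a minimal sketch of the SQL shape the new `WithCTE(Distinct(Union(...)))` cases are meant to match: a recursive CTE combined with plain UNION whose body also declares an inner CTE. The table/column names are illustrative, a `SparkSession` named `spark` is assumed, and the exact syntax accepted depends on the recursive-CTE support in this build.

```scala
// Illustrative only: per the rule above, the anchor ends up wrapped in Distinct and the
// recursive term in Except; inner CTEs that reference `r` are rewritten through
// rewriteRecursiveCTERefs before UnionLoop is built.
val result = spark.sql("""
  WITH RECURSIVE r(n) AS (
    WITH helper AS (SELECT 1 AS n)
    SELECT n FROM helper
    UNION
    SELECT n + 1 FROM r WHERE n < 5
  )
  SELECT * FROM r
""")
result.show()
```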
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
deleted file mode 100644
index fa08ae61daec4..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.analysis
-
-import org.apache.spark.sql.catalyst.expressions.{BaseGroupingSets, Expression, Literal, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
-import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.types.IntegerType
-
-/**
- * Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression.
- */
-object SubstituteUnresolvedOrdinals extends Rule[LogicalPlan] {
- private def containIntLiteral(e: Expression): Boolean = e match {
- case Literal(_, IntegerType) => true
- case gs: BaseGroupingSets => gs.children.exists(containIntLiteral)
- case _ => false
- }
-
- private def substituteUnresolvedOrdinal(expression: Expression): Expression = expression match {
- case ordinal @ Literal(index: Int, IntegerType) =>
- withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
- case e => e
- }
-
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
- t => t.containsPattern(LITERAL) && t.containsAnyPattern(AGGREGATE, SORT), ruleId) {
- case s: Sort if conf.orderByOrdinal && s.order.exists(o => containIntLiteral(o.child)) =>
- val newOrders = s.order.map {
- case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
- val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
- withOrigin(order.origin)(order.copy(child = newOrdinal))
- case other => other
- }
- withOrigin(s.origin)(s.copy(order = newOrders))
-
- case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(containIntLiteral) =>
- val newGroups = a.groupingExpressions.map {
- case ordinal @ Literal(index: Int, IntegerType) =>
- withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
- case gs: BaseGroupingSets =>
- withOrigin(gs.origin)(gs.withNewChildren(gs.children.map(substituteUnresolvedOrdinal)))
- case other => other
- }
- withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index 4a6504666d41f..adf74c489ce1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -80,7 +80,6 @@ object TableOutputResolver extends SQLConfHelper with Logging {
query: LogicalPlan,
byName: Boolean,
conf: SQLConf,
- // TODO: Only DS v1 writing will set it to true. We should enable in for DS v2 as well.
supportColDefaultValue: Boolean = false): LogicalPlan = {
if (expected.size < query.output.size) {
@@ -460,11 +459,28 @@ object TableOutputResolver extends SQLConfHelper with Logging {
}
if (resKey.length == 1 && resValue.length == 1) {
- val keyFunc = LambdaFunction(resKey.head, Seq(keyParam))
- val valueFunc = LambdaFunction(resValue.head, Seq(valueParam))
- val newKeys = ArrayTransform(MapKeys(nullCheckedInput), keyFunc)
- val newValues = ArrayTransform(MapValues(nullCheckedInput), valueFunc)
- Some(Alias(MapFromArrays(newKeys, newValues), expected.name)())
+ // If the key and value expressions have not changed, we just check original map field.
+ // Otherwise, we construct a new map by adding transformations to the keys and values.
+ if (resKey.head == keyParam && resValue.head == valueParam) {
+ Some(
+ Alias(nullCheckedInput, expected.name)(
+ nonInheritableMetadataKeys =
+ Seq(CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)))
+ } else {
+ val newKeys = if (resKey.head != keyParam) {
+ val keyFunc = LambdaFunction(resKey.head, Seq(keyParam))
+ ArrayTransform(MapKeys(nullCheckedInput), keyFunc)
+ } else {
+ MapKeys(nullCheckedInput)
+ }
+ val newValues = if (resValue.head != valueParam) {
+ val valueFunc = LambdaFunction(resValue.head, Seq(valueParam))
+ ArrayTransform(MapValues(nullCheckedInput), valueFunc)
+ } else {
+ MapValues(nullCheckedInput)
+ }
+ Some(Alias(MapFromArrays(newKeys, newValues), expected.name)())
+ }
} else {
None
}
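
As a reading aid, a standalone sketch (stand-in types, not Spark's expressions) of the decision this hunk adds for map columns on write: pass the input map through untouched when neither side needs a cast, and rebuild only the side that actually changed.

```scala
object MapWriteSketch {
  // Stand-in for "does this side need a cast/rename to match the table schema?"
  final case class Side(needsChange: Boolean)

  sealed trait PlanSketch
  // Reuse the input column as-is (the Alias-over-nullCheckedInput branch above).
  case object PassThroughMap extends PlanSketch
  // Rebuild the map, wrapping only the changed side in an ArrayTransform-like rewrite.
  final case class RebuildMap(transformKeys: Boolean, transformValues: Boolean) extends PlanSketch

  def planMapWrite(key: Side, value: Side): PlanSketch =
    if (!key.needsChange && !value.needsChange) {
      PassThroughMap
    } else {
      RebuildMap(transformKeys = key.needsChange, transformValues = value.needsChange)
    }

  def main(args: Array[String]): Unit = {
    assert(planMapWrite(Side(false), Side(false)) == PassThroughMap)
    assert(planMapWrite(Side(true), Side(false)) == RebuildMap(true, false))
  }
}
```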
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 4769970b51421..f68084803fe75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -90,7 +90,7 @@ object TypeCoercion extends TypeCoercionBase {
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
Some(numericPrecedence(index))
- case (d1: DatetimeType, d2: DatetimeType) => Some(findWiderDateTimeType(d1, d2))
+ case (d1: DatetimeType, d2: DatetimeType) => findWiderDateTimeType(d1, d2)
case (t1: DayTimeIntervalType, t2: DayTimeIntervalType) =>
Some(DayTimeIntervalType(t1.startField.min(t2.startField), t1.endField.max(t2.endField)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala
index 294ee93a3c7bc..eae7d5a74dbc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{
Unpivot
}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.types.DataType
@@ -142,7 +143,9 @@ abstract class TypeCoercionBase extends TypeCoercionHelper {
case s @ Except(left, right, isAll)
if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
- val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+ val newChildren: Seq[LogicalPlan] = withOrigin(s.origin) {
+ buildNewChildrenWithWiderTypes(left :: right :: Nil)
+ }
if (newChildren.isEmpty) {
s -> Nil
} else {
@@ -154,7 +157,9 @@ abstract class TypeCoercionBase extends TypeCoercionHelper {
case s @ Intersect(left, right, isAll)
if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
- val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+ val newChildren: Seq[LogicalPlan] = withOrigin(s.origin) {
+ buildNewChildrenWithWiderTypes(left :: right :: Nil)
+ }
if (newChildren.isEmpty) {
s -> Nil
} else {
@@ -166,7 +171,9 @@ abstract class TypeCoercionBase extends TypeCoercionHelper {
case s: Union
if s.childrenResolved && !s.byName &&
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
- val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+ val newChildren: Seq[LogicalPlan] = withOrigin(s.origin) {
+ buildNewChildrenWithWiderTypes(s.children)
+ }
if (newChildren.isEmpty) {
s -> Nil
} else {
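
A standalone sketch of why the widening step is now wrapped in `withOrigin`; the `Origin` and `CastSketch` names below are stand-ins rather than Spark's `CurrentOrigin` machinery, but the idea is the same: nodes created inside the block pick up the set operation's parse position instead of an empty origin, so errors point at the right SQL.

```scala
object OriginSketch {
  final case class Origin(line: Option[Int] = None)

  private val current = new ThreadLocal[Origin] {
    override def initialValue(): Origin = Origin()
  }

  // Run `body` with `o` as the current origin, restoring the previous one afterwards.
  def withOrigin[A](o: Origin)(body: => A): A = {
    val previous = current.get()
    current.set(o)
    try body finally current.set(previous)
  }

  // Stand-in for an expression constructor that captures the current origin when built.
  final case class CastSketch(origin: Origin)

  def main(args: Array[String]): Unit = {
    val castBuiltInsideUnion = withOrigin(Origin(line = Some(42))) { CastSketch(current.get()) }
    assert(castBuiltInsideUnion.origin.line.contains(42))
  }
}
```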
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala
index 96053904e2fb2..390ff2f3114d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala
@@ -84,7 +84,8 @@ import org.apache.spark.sql.types.{
StructType,
TimestampNTZType,
TimestampType,
- TimestampTypeExpression
+ TimestampTypeExpression,
+ TimeType
}
abstract class TypeCoercionHelper {
@@ -239,16 +240,18 @@ abstract class TypeCoercionHelper {
}
}
- protected def findWiderDateTimeType(d1: DatetimeType, d2: DatetimeType): DatetimeType =
+ protected def findWiderDateTimeType(d1: DatetimeType, d2: DatetimeType): Option[DatetimeType] =
(d1, d2) match {
+ case (_, _: TimeType) => None
+ case (_: TimeType, _) => None
case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) =>
- TimestampType
+ Some(TimestampType)
case (_: TimestampType, _: TimestampNTZType) | (_: TimestampNTZType, _: TimestampType) =>
- TimestampType
+ Some(TimestampType)
case (_: TimestampNTZType, _: DateType) | (_: DateType, _: TimestampNTZType) =>
- TimestampNTZType
+ Some(TimestampNTZType)
}
/**
@@ -283,7 +286,11 @@ abstract class TypeCoercionHelper {
case (e, _) => e
}
- InSubquery(newLhs, l.withNewPlan(Project(castedRhs, l.plan)))
+ if (newLhs != lhs || castedRhs != rhs) {
+ InSubquery(newLhs, l.withNewPlan(Project(castedRhs, l.plan)))
+ } else {
+ i
+ }
} else {
i
}
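
A standalone sketch of the new `Option`-returning widening behaviour with simplified stand-in types: any pairing that involves TIME has no wider datetime type and yields `None`, which is why the caller in `TypeCoercion` now propagates the `Option` directly instead of wrapping the result in `Some`.

```scala
object WiderDateTimeSketch {
  sealed trait DatetimeTypeSketch
  case object DateSketch extends DatetimeTypeSketch
  case object TimestampSketch extends DatetimeTypeSketch
  case object TimestampNtzSketch extends DatetimeTypeSketch
  case object TimeSketch extends DatetimeTypeSketch

  def widerDateTime(d1: DatetimeTypeSketch, d2: DatetimeTypeSketch): Option[DatetimeTypeSketch] =
    (d1, d2) match {
      case (TimeSketch, _) | (_, TimeSketch) => None // TIME never widens with other types
      case (TimestampSketch, DateSketch) | (DateSketch, TimestampSketch) => Some(TimestampSketch)
      case (TimestampSketch, TimestampNtzSketch) | (TimestampNtzSketch, TimestampSketch) =>
        Some(TimestampSketch)
      case (TimestampNtzSketch, DateSketch) | (DateSketch, TimestampNtzSketch) =>
        Some(TimestampNtzSketch)
      case _ => None // equal types are handled earlier by the caller
    }

  def main(args: Array[String]): Unit = {
    assert(widerDateTime(DateSketch, TimestampSketch) == Some(TimestampSketch))
    assert(widerDateTime(TimeSketch, DateSketch).isEmpty)
  }
}
```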
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionValidation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionValidation.scala
new file mode 100644
index 0000000000000..24097c55895e2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionValidation.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, SetOperation, Union}
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DataType
+
+object TypeCoercionValidation extends QueryErrorsBase {
+ private val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Unit]("dataTypeMismatchError")
+
+ def failOnTypeCheckResult(e: Expression, operator: Option[LogicalPlan] = None): Nothing = {
+ e.checkInputDataTypes() match {
+ case checkRes: TypeCheckResult.DataTypeMismatch =>
+ e.setTagValue(DATA_TYPE_MISMATCH_ERROR, ())
+ e.dataTypeMismatch(e, checkRes)
+ case TypeCheckResult.TypeCheckFailure(message) =>
+ e.setTagValue(DATA_TYPE_MISMATCH_ERROR, ())
+ val extraHint = operator.map(getHintForExpressionCoercion(_)).getOrElse("")
+ e.failAnalysis(
+ errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ messageParameters = Map("sqlExpr" -> toSQLExpr(e), "msg" -> message, "hint" -> extraHint)
+ )
+ case checkRes: TypeCheckResult.InvalidFormat =>
+ e.invalidFormat(checkRes)
+ }
+ }
+
+ def getHintForOperatorCoercion(plan: LogicalPlan): String = {
+ if (!SQLConf.get.ansiEnabled) {
+ ""
+ } else {
+ val nonAnsiPlan = getDefaultTypeCoercionPlan(plan)
+ var issueFixedIfAnsiOff = true
+ nonAnsiPlan match {
+ case _: Union | _: SetOperation if nonAnsiPlan.children.length > 1 =>
+ def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType)
+
+ val ref = dataTypes(nonAnsiPlan.children.head)
+ val dataTypesAreCompatibleFn = getDataTypesAreCompatibleFn(nonAnsiPlan)
+ nonAnsiPlan.children.tail.zipWithIndex.foreach {
+ case (child, ti) =>
+ // Check if the data types match.
+ dataTypes(child).zip(ref).zipWithIndex.foreach {
+ case ((dt1, dt2), ci) =>
+ if (!dataTypesAreCompatibleFn(dt1, dt2)) {
+ issueFixedIfAnsiOff = false
+ }
+ }
+ }
+
+ case _ =>
+ }
+ extraHintMessage(issueFixedIfAnsiOff)
+ }
+ }
+
+ def getDataTypesAreCompatibleFn(plan: LogicalPlan): (DataType, DataType) => Boolean = {
+ val isUnion = plan.isInstanceOf[Union]
+ if (isUnion) { (dt1: DataType, dt2: DataType) =>
+ DataType.equalsStructurally(dt1, dt2, true)
+ } else {
+ // SPARK-18058: we shall not care about the nullability of columns
+ (dt1: DataType, dt2: DataType) =>
+ TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).nonEmpty
+ }
+ }
+
+ private def getHintForExpressionCoercion(plan: LogicalPlan): String = {
+ if (!SQLConf.get.ansiEnabled) {
+ ""
+ } else {
+ val nonAnsiPlan = getDefaultTypeCoercionPlan(plan)
+ var issueFixedIfAnsiOff = true
+ getAllExpressions(nonAnsiPlan).foreach(_.foreachUp {
+ case e: Expression
+ if e.getTagValue(DATA_TYPE_MISMATCH_ERROR).isDefined &&
+ e.checkInputDataTypes().isFailure =>
+ e.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckFailure(_) | _: TypeCheckResult.DataTypeMismatch =>
+ issueFixedIfAnsiOff = false
+ }
+
+ case _ =>
+ })
+ extraHintMessage(issueFixedIfAnsiOff)
+ }
+ }
+
+ private def getDefaultTypeCoercionPlan(plan: LogicalPlan): LogicalPlan =
+ TypeCoercion.typeCoercionRules.foldLeft(plan) { case (p, rule) => rule(p) }
+
+ private def extraHintMessage(issueFixedIfAnsiOff: Boolean): String = {
+ if (issueFixedIfAnsiOff) {
+ "\nTo fix the error, you might need to add explicit type casts. If necessary set " +
+ s"${SQLConf.ANSI_ENABLED.key} to false to bypass this error."
+ } else {
+ ""
+ }
+ }
+
+ private def getAllExpressions(plan: LogicalPlan): Seq[Expression] = {
+ plan match {
+ // We only resolve `groupingExpressions` if `aggregateExpressions` is resolved first (See
+ // `ResolveReferencesInAggregate`). We should check errors in `aggregateExpressions` first.
+ case a: Aggregate => a.aggregateExpressions ++ a.groupingExpressions
+ case _ => plan.expressions
+ }
+ }
+}
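
A standalone sketch of the two compatibility checks that `getDataTypesAreCompatibleFn` distinguishes when building the ANSI hint; the types and the widening rule below are stand-ins, not Spark's `DataType` hierarchy.

```scala
object CompatibilitySketch {
  sealed trait Dt
  case object IntDt extends Dt
  case object LongDt extends Dt
  case object StringDt extends Dt

  // Stand-in for TypeCoercion.findWiderTypeForTwo.
  private def widerType(a: Dt, b: Dt): Option[Dt] = (a, b) match {
    case (x, y) if x == y => Some(x)
    case (IntDt, LongDt) | (LongDt, IntDt) => Some(LongDt)
    case _ => None
  }

  def compatibleFn(isUnion: Boolean): (Dt, Dt) => Boolean =
    if (isUnion) {
      (a, b) => a == b // UNION: types must match structurally
    } else {
      (a, b) => widerType(a, b).nonEmpty // EXCEPT/INTERSECT: a common wider type is enough
    }

  def main(args: Array[String]): Unit = {
    assert(!compatibleFn(isUnion = true)(IntDt, LongDt))
    assert(compatibleFn(isUnion = false)(IntDt, LongDt))
  }
}
```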
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index f7ab41bd6f96c..146032f1d199e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -105,7 +105,7 @@ object UnsupportedOperationChecker extends Logging {
case d: Deduplicate if d.isStreaming && d.keys.exists(hasEventTimeCol) => true
case d: DeduplicateWithinWatermark if d.isStreaming => true
case t: TransformWithState if t.isStreaming => true
- case t: TransformWithStateInPandas if t.isStreaming => true
+ case t: TransformWithStateInPySpark if t.isStreaming => true
case _ => false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ValidateSubqueryExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ValidateSubqueryExpression.scala
index cc863d4c03ef6..d6b7a4dccb907 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ValidateSubqueryExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ValidateSubqueryExpression.scala
@@ -400,6 +400,10 @@ object ValidateSubqueryExpression
&& SQLConf.get.getConf(SQLConf.DECORRELATE_SET_OPS_ENABLED))
p.children.foreach(child => checkPlan(child, aggregated, childCanContainOuter))
+ // SQLFunctionNode serves as a container for the underlying SQL function plan.
+ case s: SQLFunctionNode =>
+ s.children.foreach(child => checkPlan(child, aggregated, canContainOuter))
+
// Category 2:
// These operators can be anywhere in a correlated subquery.
// so long as they do not host outer references in the operators.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala
index ccbb82a0bac34..63642e3ab1f10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala
@@ -17,15 +17,19 @@
package org.apache.spark.sql.catalyst.analysis.resolver
+import java.util.IdentityHashMap
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{
- AnalysisErrorAt,
AnsiTypeCoercion,
CollationTypeCoercion,
TypeCoercion
}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, OuterReference, SubExprUtils}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Sort}
+import org.apache.spark.sql.errors.QueryCompilationErrors
/**
* A resolver for [[AggregateExpression]]s which are introduced while resolving an
@@ -37,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Project}
* in the [[FunctionResolver]].
*/
class AggregateExpressionResolver(
+ operatorResolver: Resolver,
expressionResolver: ExpressionResolver,
timezoneAwareExpressionResolver: TimezoneAwareExpressionResolver)
extends TreeNodeResolver[AggregateExpression, Expression]
@@ -53,85 +58,175 @@ class AggregateExpressionResolver(
private val expressionResolutionContextStack =
expressionResolver.getExpressionResolutionContextStack
+ private val subqueryRegistry = operatorResolver.getSubqueryRegistry
+ private val traversals = expressionResolver.getExpressionTreeTraversals
+ private val autoGeneratedAliasProvider = new AutoGeneratedAliasProvider(
+ expressionResolver.getExpressionIdAssigner
+ )
/**
- * Resolves the given [[AggregateExpression]] by applying:
- * - Type coercion rules
- * - Validity checks. Those include:
- * - Whether the [[AggregateExpression]] is under a valid operator.
- * - Whether there is a nested [[AggregateExpression]].
- * - Whether there is a nondeterministic child.
- * - Updates to the [[ExpressionResolver.expressionResolutionContextStack]]
+ * Resolves the given [[AggregateExpression]]:
+ * - Apply type coercion rules;
+ * - Validate the [[AggregateExpression]]:
+ * 1. Nested aggregate functions are not allowed;
+ * 2. Nondeterministic expressions in the subtree of a related aggregate function are not
+ * allowed;
+ * 3. The mix of outer and local references is not allowed.
+ * - Update the [[ExpressionResolver.expressionResolutionContextStack]];
+ * - Handle the outer aggregate expression in a special way (see
+ * [[handleOuterAggregateExpression]]).
*/
override def resolve(aggregateExpression: AggregateExpression): Expression = {
+ val expressionResolutionContext = expressionResolutionContextStack.peek()
+
val aggregateExpressionWithTypeCoercion =
- withResolvedChildren(aggregateExpression, typeCoercionResolver.resolve)
+ withResolvedChildren(aggregateExpression, typeCoercionResolver.resolve _)
+ .asInstanceOf[AggregateExpression]
- throwIfNotUnderValidOperator(aggregateExpression)
- throwIfNestedAggregateExists(aggregateExpressionWithTypeCoercion)
- throwIfHasNondeterministicChildren(aggregateExpressionWithTypeCoercion)
+ validateResolvedAggregateExpression(aggregateExpressionWithTypeCoercion)
- expressionResolutionContextStack
- .peek()
- .hasAggregateExpressionsInASubtree = true
+ expressionResolutionContext.hasAggregateExpressions = true
// There are two different cases that we handle regarding the value of the flag:
//
// - We have an attribute under an `AggregateExpression`:
// {{{ SELECT COUNT(col1) FROM VALUES (1); }}}
- // In this case, value of the `hasAttributeInASubtree` flag should be `false` as it
- // indicates whether there is an attribute in the subtree that's not `AggregateExpression`
- // so we can throw the `MISSING_GROUP_BY` exception appropriately.
+ // In this case, value of the `hasAttributeOutsideOfAggregateExpressions` flag should be
+ // `false` as it indicates whether there is an attribute in the subtree that's not
+ // `AggregateExpression` so we can throw the `MISSING_GROUP_BY` exception appropriately.
//
// - In the following example:
// {{{ SELECT COUNT(*), col1 + 1 FROM VALUES (1); }}}
// It would be `true` as described above.
- expressionResolutionContextStack.peek().hasAttributeInASubtree = false
+ expressionResolutionContext.hasAttributeOutsideOfAggregateExpressions = false
- aggregateExpressionWithTypeCoercion
+ if (expressionResolutionContext.hasOuterReferences) {
+ handleOuterAggregateExpression(aggregateExpressionWithTypeCoercion)
+ } else {
+ traversals.current.parentOperator match {
+ case Sort(_, _, aggregate: Aggregate, _) =>
+ handleAggregateExpressionInSort(aggregateExpressionWithTypeCoercion, aggregate)
+ case other =>
+ aggregateExpressionWithTypeCoercion
+ }
+ }
}
- private def throwIfNotUnderValidOperator(aggregateExpression: AggregateExpression): Unit = {
- expressionResolver.getParentOperator.get match {
- case _: Aggregate | _: Project =>
- case filter: Filter =>
- filter.failAnalysis(
- errorClass = "INVALID_WHERE_CONDITION",
- messageParameters = Map(
- "condition" -> toSQLExpr(filter.condition),
- "expressionList" -> Seq(aggregateExpression).mkString(", ")
- )
- )
- case other =>
- other.failAnalysis(
- errorClass = "UNSUPPORTED_EXPR_FOR_OPERATOR",
- messageParameters = Map(
- "invalidExprSqls" -> Seq(aggregateExpression).mkString(", ")
- )
- )
+ private def validateResolvedAggregateExpression(
+ aggregateExpression: AggregateExpression): Unit = {
+ if (expressionResolutionContextStack.peek().hasAggregateExpressions) {
+ throwNestedAggregateFunction(aggregateExpression)
+ }
+
+ val nonDeterministicChild =
+ aggregateExpression.aggregateFunction.children.collectFirst {
+ case child if !child.deterministic => child
+ }
+ if (nonDeterministicChild.nonEmpty) {
+ throwAggregateFunctionWithNondeterministicExpression(
+ aggregateExpression,
+ nonDeterministicChild.get
+ )
}
}
- private def throwIfNestedAggregateExists(aggregateExpression: AggregateExpression): Unit = {
- if (expressionResolutionContextStack
- .peek()
- .hasAggregateExpressionsInASubtree) {
- aggregateExpression.failAnalysis(
- errorClass = "NESTED_AGGREGATE_FUNCTION",
- messageParameters = Map.empty
+ /**
+ * If the [[AggregateExpression]] has outer references in its subtree, we need to handle it in a
+ * special way. The whole process is explained in the [[SubqueryScope]] scaladoc, but in short
+ * we need to:
+ * - Validate that we don't have local references in this subtree;
+ * - Create a new subtree without [[OuterReference]]s;
+ * - Alias this subtree and put it inside the current [[SubqueryScope]];
+ * - If outer aggregates are allowed, replace the [[AggregateExpression]] with an
+ * [[OuterReference]] to the auto-generated [[Alias]] that we created. This alias will later
+ * be injected into the outer [[Aggregate]];
+ * - In case we have an [[AggregateExpression]] inside a [[Sort]] operator, we need to handle it
+ * in a special way (see [[handleAggregateExpressionInSort]] for more details).
+ * - Return the original [[AggregateExpression]] otherwise. This is done to stay compatible
+ * with the fixed-point Analyzer - a proper exception will be thrown later by
+ * [[ValidateSubqueryExpression]].
+ */
+ private def handleOuterAggregateExpression(
+ aggregateExpression: AggregateExpression): Expression = {
+ if (expressionResolutionContextStack.peek().hasLocalReferences) {
+ throw QueryCompilationErrors.mixedRefsInAggFunc(
+ aggregateExpression.sql,
+ aggregateExpression.origin
+ )
+ }
+
+ if (subqueryRegistry.currentScope.isOuterAggregateAllowed) {
+ val aggregateExpressionWithStrippedOuterReferences =
+ SubExprUtils.stripOuterReference(aggregateExpression)
+
+ val outerAggregateExpressionAlias = autoGeneratedAliasProvider.newOuterAlias(
+ child = aggregateExpressionWithStrippedOuterReferences
)
+ subqueryRegistry.currentScope.addOuterAggregateExpression(
+ outerAggregateExpressionAlias,
+ aggregateExpressionWithStrippedOuterReferences
+ )
+
+ OuterReference(outerAggregateExpressionAlias.toAttribute)
+ } else {
+ aggregateExpression
}
}
- private def throwIfHasNondeterministicChildren(aggregateExpression: AggregateExpression): Unit = {
- aggregateExpression.aggregateFunction.children.foreach(child => {
- if (!child.deterministic) {
- child.failAnalysis(
- errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION",
- messageParameters = Map("sqlExpr" -> toSQLExpr(aggregateExpression))
- )
+ /**
+ * If we order by an [[AggregateExpression]] which is not present in the [[Aggregate]] operator
+ * (child of the [[Sort]]) we have to extract it (by adding it to the
+ * `extractedAggregateExpressionAliases` list of the current expression tree traversal) and add
+ * it to the [[Aggregate]] operator afterwards (this is done in the [[SortResolver]]).
+ */
+ private def handleAggregateExpressionInSort(
+ aggregateExpression: Expression,
+ aggregate: Aggregate): Expression = {
+ val aliasChildToAliasInAggregateExpressions = new IdentityHashMap[Expression, Alias]
+ val aggregateExpressionsSemanticComparator = new SemanticComparator(
+ aggregate.aggregateExpressions.collect {
+ case alias: Alias =>
+ aliasChildToAliasInAggregateExpressions.put(alias.child, alias)
+ alias.child
}
- })
+ )
+
+ val referencedAggregateExpression =
+ aggregateExpressionsSemanticComparator.collectFirst(aggregateExpression)
+
+ referencedAggregateExpression match {
+ case Some(expression) =>
+ aliasChildToAliasInAggregateExpressions.get(expression) match {
+ case null =>
+ throw SparkException.internalError(
+ s"No parent alias for expression $expression while extracting aggregate" +
+ s"expressions in Sort operator."
+ )
+ case alias: Alias => alias.toAttribute
+ }
+ case None =>
+ val alias = autoGeneratedAliasProvider.newAlias(child = aggregateExpression)
+ traversals.current.extractedAggregateExpressionAliases.add(alias)
+ alias.toAttribute
+ }
+ }
+
+ private def throwNestedAggregateFunction(aggregateExpression: AggregateExpression): Nothing = {
+ throw new AnalysisException(
+ errorClass = "NESTED_AGGREGATE_FUNCTION",
+ messageParameters = Map.empty,
+ origin = aggregateExpression.origin
+ )
+ }
+
+ private def throwAggregateFunctionWithNondeterministicExpression(
+ aggregateExpression: AggregateExpression,
+ nonDeterministicChild: Expression): Nothing = {
+ throw new AnalysisException(
+ errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION",
+ messageParameters = Map("sqlExpr" -> toSQLExpr(aggregateExpression)),
+ origin = nonDeterministicChild.origin
+ )
}
}
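
A small illustrative session for the `Sort`-over-`Aggregate` shape that `handleAggregateExpressionInSort` extracts (the outer-aggregate path for correlated subqueries is handled analogously via `newOuterAlias`). Assumes a `SparkSession` named `spark`; the data and view name are made up.

```scala
// Illustrative only; the temp view exists purely so the analyzer has something to resolve.
import spark.implicits._

Seq(("eng", 100), ("eng", 120), ("hr", 90)).toDF("dept", "salary")
  .createOrReplaceTempView("emp")

// ORDER BY an aggregate that is not in the grouped select list: the resolver extracts
// MAX(salary) into the Aggregate below the Sort and sorts by its auto-generated alias.
spark.sql("SELECT dept FROM emp GROUP BY dept ORDER BY MAX(salary)").show()
```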
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala
new file mode 100644
index 0000000000000..f39e036807b6e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import java.util.LinkedHashMap
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.catalyst.analysis.{
+ withPosition,
+ AnalysisErrorAt,
+ NondeterministicExpressionCollection,
+ UnresolvedAttribute
+}
+import org.apache.spark.sql.catalyst.expressions.{
+ Alias,
+ Expression,
+ ExprUtils,
+ IntegerLiteral,
+ Literal,
+ NamedExpression
+}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * Resolves an [[Aggregate]] by resolving its child, aggregate expressions and grouping
+ * expressions. Updates the [[NameScopeStack]] with its output and performs validation
+ * related to [[Aggregate]] resolution.
+ */
+class AggregateResolver(operatorResolver: Resolver, expressionResolver: ExpressionResolver)
+ extends TreeNodeResolver[Aggregate, Aggregate] {
+ private val scopes = operatorResolver.getNameScopes
+
+ /**
+ * Resolve [[Aggregate]] operator.
+ *
+ * 1. Resolve the child (inline table).
+ * 2. Resolve aggregate expressions using [[ExpressionResolver.resolveAggregateExpressions]].
+ * 3. If there's just one [[UnresolvedAttribute]] with a single-part name "ALL", expand it using
+ * aggregate expressions which don't contain aggregate functions. There should not exist a
+ * column with that name in the lower operator's output, otherwise it takes precedence.
+ * 4. Resolve grouping expressions using [[ExpressionResolver.resolveGroupingExpressions]]. This
+ * includes alias references to aggregate expressions, which is done in
+ * [[NameScope.resolveMultipartName]].
+ * 5. Substitute ordinals with aggregate expressions in appropriate places.
+ * 6. Substitute non-deterministic expressions with derived attribute references to an
+ * artificial [[Project]] list.
+ *
+ * At the end of resolution we validate the [[Aggregate]] using the
+ * [[ExprUtils.assertValidAggregation]], update the `scopes` with the output of [[Aggregate]] and
+ * return the result.
+ */
+ def resolve(unresolvedAggregate: Aggregate): Aggregate = {
+ val resolvedAggregate = scopes.withNewScope() {
+ val resolvedChild = operatorResolver.resolve(unresolvedAggregate.child)
+
+ val resolvedAggregateExpressions = expressionResolver.resolveAggregateExpressions(
+ unresolvedAggregate.aggregateExpressions,
+ unresolvedAggregate
+ )
+
+ val resolvedGroupingExpressions =
+ if (canGroupByAll(unresolvedAggregate.groupingExpressions)) {
+ tryResolveGroupByAll(
+ resolvedAggregateExpressions,
+ unresolvedAggregate
+ )
+ } else {
+ val partiallyResolvedGroupingExpressions = expressionResolver.resolveGroupingExpressions(
+ unresolvedAggregate.groupingExpressions,
+ unresolvedAggregate
+ )
+ withPosition(unresolvedAggregate) {
+ tryReplaceOrdinalsInGroupingExpressions(
+ partiallyResolvedGroupingExpressions,
+ resolvedAggregateExpressions
+ )
+ }
+ }
+
+ val resolvedAggregate = unresolvedAggregate.copy(
+ groupingExpressions = resolvedGroupingExpressions,
+ aggregateExpressions = resolvedAggregateExpressions.expressions,
+ child = resolvedChild
+ )
+
+ tryPullOutNondeterministic(resolvedAggregate)
+ }
+
+ // TODO: This validation function does a post-traversal. This is discouraged in single-pass
+ // Analyzer.
+ ExprUtils.assertValidAggregation(resolvedAggregate)
+
+ scopes.overwriteOutputAndExtendHiddenOutput(
+ output = resolvedAggregate.aggregateExpressions.map(_.toAttribute)
+ )
+
+ resolvedAggregate
+ }
+
+ /**
+ * Replaces the ordinals with the actual expressions from the resolved aggregate expression list
+ * or throws if an ordinal is out of range or refers to an unsupported aggregate expression.
+ * There are several cases:
+ * - If the aggregate expression referenced by the ordinal is a [[Literal]] with the Integer
+ * data type - preserve the ordinal literal in order to pass logical plan comparison.
+ * - If [[SQLConf.groupByOrdinal]] flag is set to false, treat the grouping expression as
+ * a [[Literal]] instead of ordinal.
+ * - If aggregate expression is an [[Alias]] return [[Alias.child]].
+ * - Otherwise, replace the ordinal with the aggregate expression.
+ * Remove all the leftover [[Alias]]es at the end of resolution.
+ * For the query:
+ *
+ * {{{ SELECT col1 + 1, col1, col2 FROM VALUES(1, 2) GROUP BY 2, col2, 3; }}}
+ *
+ * It would replace `2` with `col1` and `3` with `col2` so the final grouping expression list
+ * would be: [col1, col2, col2].
+ *
+ * In case of having an integer [[Literal]] in the aggregate expressions which is referenced by an
+ * ordinal, example and final grouping expression list are the following:
+ * - Example:
+ * {{{ SELECT col1 + 1, col1, 10 FROM VALUES(1, 2) GROUP BY 2, col2, 3; }}}
+ * - Grouping expressions:
+ * [col1, col2, 3] // we preserve the ordinal instead of replacing it.
+ */
+ private def tryReplaceOrdinalsInGroupingExpressions(
+ groupingExpressions: Seq[Expression],
+ aggregateExpressions: ResolvedAggregateExpressions): Seq[Expression] = {
+ val aggregateExpressionsArray = aggregateExpressions.expressions.toArray
+ val groupByOrdinal = conf.groupByOrdinal
+ groupingExpressions.map { expression =>
+ val maybeGroupByOrdinal = if (groupByOrdinal) {
+ tryReplaceOrdinalInGroupingExpression(
+ expression,
+ aggregateExpressionsArray,
+ aggregateExpressions
+ )
+ } else {
+ expression
+ }
+ maybeGroupByOrdinal match {
+ case alias: Alias =>
+ alias.child
+ case other => other
+ }
+ }
+ }
+
+ private def tryReplaceOrdinalInGroupingExpression(
+ groupingExpression: Expression,
+ aggregateExpressionsArray: Array[NamedExpression],
+ resolvedAggregateExpressions: ResolvedAggregateExpressions): Expression = {
+ TryExtractOrdinal(groupingExpression) match {
+ case Some(ordinal) =>
+ if (ordinal > aggregateExpressionsArray.length) {
+ throw QueryCompilationErrors.groupByPositionRangeError(
+ ordinal,
+ aggregateExpressionsArray.length
+ )
+ }
+
+ if (resolvedAggregateExpressions.hasStar) {
+ throw QueryCompilationErrors.starNotAllowedWhenGroupByOrdinalPositionUsedError()
+ }
+
+ if (resolvedAggregateExpressions.expressionIndexesWithAggregateFunctions
+ .contains(ordinal - 1)) {
+ throw QueryCompilationErrors.groupByPositionRefersToAggregateFunctionError(
+ ordinal,
+ aggregateExpressionsArray(ordinal - 1)
+ )
+ }
+
+ val resolvedOrdinalAggregateExpression =
+ aggregateExpressionsArray(ordinal - 1) match {
+ case alias: Alias =>
+ alias.child
+ case other => other
+ }
+
+ resolvedOrdinalAggregateExpression match {
+ case Literal(_: Int, IntegerType) =>
+ Literal(ordinal)
+ case other => other
+ }
+ case None => groupingExpression
+ }
+ }
+
+ /**
+ * Resolve `GROUP BY ALL`.
+ *
+ * Examples below show which queries should be resolved with `tryResolveGroupByAll` and which
+ * should be resolved generically (using the [[ExpressionResolver.resolveGroupingExpressions]]):
+ *
+ * Example 1:
+ *
+ * {{{
+ * -- Table `table_1` has a column `all`.
+ * SELECT * from table_1 GROUP BY all;
+ * }}}
+ * this one should be grouped by the column `all`.
+ *
+ * Example 2:
+ *
+ * {{{
+ * -- Table `table_2` doesn't have a column `all`.
+ * SELECT * from table_2 GROUP BY all;
+ * }}}
+ * this one should be grouped by all the columns from `table_2`.
+ *
+ * Example 3:
+ *
+ * {{{
+ * -- Table `table_3` doesn't have a column `all` and there are other grouping expressions.
+ * SELECT * from table_3 GROUP BY all, column;
+ * }}}
+ * this one should be grouped by column `all` which doesn't exist so `UNRESOLVED_COLUMN`
+ * exception is thrown.
+ *
+ * Example 4:
+ *
+ * {{{ SELECT col1, col2 + 1, COUNT(col1 + 1) FROM VALUES(1, 2) GROUP BY ALL; }}}
+ * this one should be grouped by keyword `ALL`. It means that the grouping expressions list is
+ * going to contain all the aggregate expressions that don't have aggregate expressions in their
+ * subtrees. The grouping expressions list will be [col1, col2 + 1], and COUNT(col1 + 1) won't be
+ * included, being an [[AggregateExpression]].
+ *
+ * Example 5:
+ *
+ * {{{ SELECT col1, 5 FROM VALUES(1) GROUP BY ALL; }}}
+ * this one should be grouped by keyword `ALL`. If there is an aggregate expression which is a
+ * [[Literal]] with the Integer data type - preserve the ordinal literal in order to pass logical
+ * plan comparison. The grouping expressions list will be [col1, 2].
+ */
+ private def tryResolveGroupByAll(
+ aggregateExpressions: ResolvedAggregateExpressions,
+ aggregate: Aggregate): Seq[Expression] = {
+ if (aggregateExpressions.resolvedExpressionsWithoutAggregates.isEmpty &&
+ aggregateExpressions.hasAttributeOutsideOfAggregateExpressions) {
+ aggregate.failAnalysis(
+ errorClass = "UNRESOLVED_ALL_IN_GROUP_BY",
+ messageParameters = Map.empty
+ )
+ }
+
+ aggregateExpressions.resolvedExpressionsWithoutAggregates.zipWithIndex.map {
+ case (expression, index) =>
+ expression match {
+ case IntegerLiteral(_) =>
+ Literal(index + 1)
+ case _ => expression
+ }
+ }
+ }
+
+ /**
+ * In case there are non-deterministic expressions in either `groupingExpressions` or
+ * `aggregateExpressions` replace them with attributes created out of corresponding
+ * non-deterministic expression. Example:
+ *
+ * {{{ SELECT RAND() GROUP BY 1; }}}
+ *
+ * This query would have the following analyzed plan:
+ * Aggregate(
+ * groupingExpressions = [AttributeReference(_nonDeterministic)]
+ * aggregateExpressions = [Alias(AttributeReference(_nonDeterministic), `rand()`)]
+ * child = Project(
+ * projectList = [Alias(Rand(...), `_nondeterministic`)]
+ * child = OneRowRelation
+ * )
+ * )
+ */
+ private def tryPullOutNondeterministic(aggregate: Aggregate): Aggregate = {
+ val nondeterministicToAttributes: LinkedHashMap[Expression, NamedExpression] =
+ NondeterministicExpressionCollection.getNondeterministicToAttributes(
+ aggregate.groupingExpressions
+ )
+
+ if (!nondeterministicToAttributes.isEmpty) {
+ val newChild = Project(
+ scopes.current.output ++ nondeterministicToAttributes.values.asScala.toSeq,
+ aggregate.child
+ )
+ val resolvedAggregateExpressions = aggregate.aggregateExpressions.map { expression =>
+ PullOutNondeterministicExpressionInExpressionTree(expression, nondeterministicToAttributes)
+ }
+ val resolvedGroupingExpressions = aggregate.groupingExpressions.map { expression =>
+ PullOutNondeterministicExpressionInExpressionTree(
+ expression,
+ nondeterministicToAttributes
+ )
+ }
+ aggregate.copy(
+ groupingExpressions = resolvedGroupingExpressions,
+ aggregateExpressions = resolvedAggregateExpressions,
+ child = newChild
+ )
+ } else {
+ aggregate
+ }
+ }
+
+ private def canGroupByAll(expressions: Seq[Expression]): Boolean = {
+ val isOrderByAll = expressions match {
+ case Seq(unresolvedAttribute: UnresolvedAttribute) =>
+ unresolvedAttribute.equalsIgnoreCase("ALL")
+ case _ => false
+ }
+ isOrderByAll && scopes.current
+ .resolveMultipartName(Seq("ALL"))
+ .candidates
+ .isEmpty
+ }
+}
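
For reference, a couple of illustrative queries exercising the ordinal and `GROUP BY ALL` paths described above; assumes a `SparkSession` named `spark`, with made-up data and view name.

```scala
// Illustrative only.
import spark.implicits._

Seq((1, 2), (1, 3), (2, 5)).toDF("col1", "col2").createOrReplaceTempView("t")

// GROUP BY ordinal: `1` is replaced by the first select-list expression, i.e. `col1`.
spark.sql("SELECT col1, SUM(col2) FROM t GROUP BY 1").show()

// GROUP BY ALL: expands to every select-list expression without an aggregate function,
// i.e. grouping by `col1` here.
spark.sql("SELECT col1, SUM(col2) FROM t GROUP BY ALL").show()
```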
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala
index ed013232ac84d..06b340738d320 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala
@@ -35,9 +35,10 @@ class AliasResolver(expressionResolver: ExpressionResolver)
* resolution, after which they will be removed in the post-processing phase.
*/
override def resolve(unresolvedAlias: UnresolvedAlias): NamedExpression =
- scopes.top.lcaRegistry.withNewLcaScope {
+ scopes.current.lcaRegistry.withNewLcaScope {
val aliasWithResolvedChildren =
- withResolvedChildren(unresolvedAlias, expressionResolver.resolve)
+ withResolvedChildren(unresolvedAlias, expressionResolver.resolve _)
+ .asInstanceOf[UnresolvedAlias]
val resolvedAlias =
AliasResolution.resolve(aliasWithResolvedChildren).asInstanceOf[NamedExpression]
@@ -48,9 +49,7 @@ class AliasResolver(expressionResolver: ExpressionResolver)
s"unsupported expression: ${multiAlias.getClass.getName}"
)
case alias: Alias =>
- expressionResolver.getExpressionIdAssigner
- .mapExpression(alias)
- .asInstanceOf[Alias]
+ expressionResolver.getExpressionIdAssigner.mapExpression(alias)
}
}
@@ -59,11 +58,50 @@ class AliasResolver(expressionResolver: ExpressionResolver)
* resolve its children and afterwards reassign exprId to the resulting [[Alias]].
*/
def handleResolvedAlias(alias: Alias): Alias = {
- scopes.top.lcaRegistry.withNewLcaScope {
- val aliasWithResolvedChildren = withResolvedChildren(alias, expressionResolver.resolve)
- expressionResolver.getExpressionIdAssigner
- .mapExpression(aliasWithResolvedChildren)
- .asInstanceOf[Alias]
+ val resolvedAlias = scopes.current.lcaRegistry.withNewLcaScope {
+ val aliasWithResolvedChildren =
+ withResolvedChildren(alias, expressionResolver.resolve _).asInstanceOf[Alias]
+
+ expressionResolver.getExpressionIdAssigner.mapExpression(aliasWithResolvedChildren)
}
+
+ collapseAlias(resolvedAlias)
}
+
+ /**
+ * In case where there are two explicit [[Alias]]es, one on top of the other, remove the bottom
+ * one. For the example below:
+ *
+ * - df.select($"column".as("alias_1").as("alias_2"))
+ *
+ * the plan is:
+ *
+ * Project[
+ * Alias("alias_2")(
+ * Alias("alias_1")(id)
+ * )
+ * ]( ... )
+ *
+ * and after the `collapseAlias` call (removing the bottom one) it would be:
+ *
+ * Project[
+ * Alias("alias_2")(id)
+ * ]( ... )
+ */
+ private def collapseAlias(alias: Alias): Alias =
+ alias.child match {
+ case innerAlias: Alias =>
+ val metadata = if (alias.metadata.isEmpty) {
+ None
+ } else {
+ Some(alias.metadata)
+ }
+ alias.copy(child = innerAlias.child)(
+ exprId = alias.exprId,
+ qualifier = alias.qualifier,
+ explicitMetadata = metadata,
+ nonInheritableMetadataKeys = alias.nonInheritableMetadataKeys
+ )
+ case _ => alias
+ }
}
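
A standalone sketch of the alias-collapsing step with stand-in case classes (not Spark's `Alias`): when an explicit alias directly wraps another explicit alias, the inner one is dropped and only the outer name survives.

```scala
object CollapseAliasSketch {
  sealed trait Expr
  final case class Column(name: String) extends Expr
  final case class AliasSketch(child: Expr, name: String) extends Expr

  // Drop the inner alias and keep the outer name, mirroring collapseAlias above.
  def collapse(alias: AliasSketch): AliasSketch = alias.child match {
    case inner: AliasSketch => alias.copy(child = inner.child)
    case _ => alias
  }

  def main(args: Array[String]): Unit = {
    val nested = AliasSketch(AliasSketch(Column("id"), "alias_1"), "alias_2")
    assert(collapse(nested) == AliasSketch(Column("id"), "alias_2"))
  }
}
```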
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AnalyzerBridgeState.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AnalyzerBridgeState.scala
index d3e93c82dfa21..a3fd6cf1cc874 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AnalyzerBridgeState.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AnalyzerBridgeState.scala
@@ -19,22 +19,37 @@ package org.apache.spark.sql.catalyst.analysis.resolver
import java.util.HashMap
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.catalog.UnresolvedCatalogRelation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
/**
* The [[AnalyzerBridgeState]] is a state passed from legacy [[Analyzer]] to the single-pass
- * [[Resolver]].
+ * [[Resolver]]. It is used in dual-run mode (when
+ * [[ANALYZER_SINGLE_PASS_RESOLVER_RELATION_BRIDGING_ENABLED]] is true).
*
- * @param relationsWithResolvedMetadata A map from [[UnresolvedRelation]] to the relations with
+ * @param relationsWithResolvedMetadata A map from [[BridgedRelationId]] to the relations with
* resolved metadata. It allows us to reuse the relation metadata and avoid duplicate
- * catalog/table lookups in dual-run mode (when
- * [[ANALYZER_SINGLE_PASS_RESOLVER_RELATION_BRIDGING_ENABLED]] is true).
+ * catalog/table lookups.
+ * @param catalogRelationsWithResolvedMetadata A map from [[UnresolvedCatalogRelation]] to the
+ * relations with resolved metadata. It allows us to reuse the relation metadata and avoid
+ * duplicate catalog/table lookups.
*/
case class AnalyzerBridgeState(
relationsWithResolvedMetadata: AnalyzerBridgeState.RelationsWithResolvedMetadata =
- new AnalyzerBridgeState.RelationsWithResolvedMetadata)
+ new AnalyzerBridgeState.RelationsWithResolvedMetadata,
+ catalogRelationsWithResolvedMetadata: AnalyzerBridgeState.CatalogRelationsWithResolvedMetadata =
+ new AnalyzerBridgeState.CatalogRelationsWithResolvedMetadata
+) {
+ def addUnresolvedRelation(unresolvedRelation: UnresolvedRelation, relation: LogicalPlan): Unit = {
+ relationsWithResolvedMetadata.put(
+ BridgedRelationId(unresolvedRelation, AnalysisContext.get.catalogAndNamespace),
+ relation
+ )
+ }
+}
object AnalyzerBridgeState {
- type RelationsWithResolvedMetadata = HashMap[UnresolvedRelation, LogicalPlan]
+ type RelationsWithResolvedMetadata = HashMap[BridgedRelationId, LogicalPlan]
+ type CatalogRelationsWithResolvedMetadata = HashMap[UnresolvedCatalogRelation, LogicalPlan]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AttributeScopeStack.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AttributeScopeStack.scala
index 6f9d6defd2edb..b36a5ac6feb42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AttributeScopeStack.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AttributeScopeStack.scala
@@ -21,11 +21,15 @@ import java.util.ArrayDeque
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+/**
+ * A scope with registered attributes encountered during the logical plan validation process. We
+ * use [[AttributeSet]] here to check the equality of attributes based on their expression IDs.
+ */
+case class AttributeScope(attributes: AttributeSet, isSubqueryRoot: Boolean = false)
+
/**
* The [[AttributeScopeStack]] is used to validate that the attribute which was encountered by the
- * [[ExpressionResolutionValidator]] is in the current operator's visibility scope. We use
- * [[AttributeSet]] as scope implementation here to check the equality of attributes based on their
- * expression IDs.
+ * [[ExpressionResolutionValidator]] is in the current operator's visibility scope.
*
* E.g. for the following SQL query:
* {{{
@@ -46,44 +50,69 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
* new ID #3 for an alias of `a + col2`.
*/
class AttributeScopeStack {
- private val stack = new ArrayDeque[AttributeSet]
- push()
+ private val stack = new ArrayDeque[AttributeScope]
+ stack.push(AttributeScope(attributes = AttributeSet(Seq.empty)))
/**
- * Get the relevant attribute scope in the context of the current operator.
+ * Check if the `attribute` is present in this stack. We check the current scope by default. If
+ * `isOuterReference` is true, we check the first scope above our subquery root.
*/
- def top: AttributeSet = {
- stack.peek()
+ def contains(attribute: Attribute, isOuterReference: Boolean = false): Boolean = {
+ if (!isOuterReference) {
+ current.attributes.contains(attribute)
+ } else {
+ outer match {
+ case Some(outer) => outer.attributes.contains(attribute)
+ case _ => false
+ }
+ }
}
/**
* Overwrite current relevant scope with a sequence of attributes which is an output of some
* operator. `attributes` can have duplicate IDs if the output of the operator contains multiple
* occurrences of the same attribute.
*/
- def overwriteTop(attributes: Seq[Attribute]): Unit = {
- stack.pop()
- stack.push(AttributeSet(attributes))
+ def overwriteCurrent(attributes: Seq[Attribute]): Unit = {
+ val current = stack.pop()
+
+ stack.push(current.copy(attributes = AttributeSet(attributes)))
}
/**
* Execute `body` in the context of a fresh attribute scope. Used by [[Project]] and [[Aggregate]]
* validation code since those operators introduce a new scope with fresh expression IDs.
*/
- def withNewScope[R](body: => R): Unit = {
- push()
+ def withNewScope[R](isSubqueryRoot: Boolean = false)(body: => R): Unit = {
+ stack.push(
+ AttributeScope(
+ attributes = AttributeSet(Seq.empty),
+ isSubqueryRoot = isSubqueryRoot
+ )
+ )
try {
body
} finally {
- pop()
+ stack.pop()
}
}
- private def push(): Unit = {
- stack.push(AttributeSet(Seq.empty))
- }
+ override def toString: String = stack.toString
+
+ private def current: AttributeScope = stack.peek
+
+ private def outer: Option[AttributeScope] = {
+ var outerScope: Option[AttributeScope] = None
+
+ val iter = stack.iterator
+ while (iter.hasNext && !outerScope.isDefined) {
+ val scope = iter.next
+
+ if (scope.isSubqueryRoot && iter.hasNext) {
+ outerScope = Some(iter.next)
+ }
+ }
- private def pop(): Unit = {
- stack.pop()
+ outerScope
}
}
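
A standalone sketch of the outer-reference lookup with stand-in types: the current scope is the head of the stack, and the outer scope is the first scope above the nearest subquery root.

```scala
object ScopeStackSketch {
  final case class Scope(attributes: Set[String], isSubqueryRoot: Boolean = false)

  // Walk from the top of the stack; once the nearest subquery-root scope is passed,
  // the next scope (if any) is the outer query's scope.
  def outerScope(stack: List[Scope]): Option[Scope] =
    stack.dropWhile(scope => !scope.isSubqueryRoot).drop(1).headOption

  def contains(stack: List[Scope], attribute: String, isOuterReference: Boolean): Boolean =
    if (isOuterReference) outerScope(stack).exists(_.attributes.contains(attribute))
    else stack.headOption.exists(_.attributes.contains(attribute))

  def main(args: Array[String]): Unit = {
    // Head of the list is the innermost scope (the subquery root), as in a stack.
    val stack = List(Scope(Set("a"), isSubqueryRoot = true), Scope(Set("b")))
    assert(contains(stack, "a", isOuterReference = false))
    assert(contains(stack, "b", isOuterReference = true))
  }
}
```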
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AutoGeneratedAliasProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AutoGeneratedAliasProvider.scala
new file mode 100644
index 0000000000000..2a49581b3499b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AutoGeneratedAliasProvider.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
+import org.apache.spark.sql.catalyst.util.toPrettySQL
+import org.apache.spark.sql.types.Metadata
+
+/**
+ * [[AutoGeneratedAliasProvider]] is a tool to create auto-generated aliases in the plan.
+ * All the auto-generated aliases have to be registered in [[ExpressionIdAssigner]].
+ */
+class AutoGeneratedAliasProvider(expressionIdAssigner: ExpressionIdAssigner) {
+
+ /**
+ * Create a new auto-generated [[Alias]]. If the `name` is not provided, [[toPrettySql]] is going
+ * to be used to generate a proper alias name. We call [[ExpressionIdAssigner.mapExpression]] to
+ * register this alias in the [[ExpressionIdAssigner]].
+ */
+ def newAlias(
+ child: Expression,
+ name: Option[String] = None,
+ explicitMetadata: Option[Metadata] = None): Alias = {
+ newAliasImpl(child = child, name = name, explicitMetadata = explicitMetadata)
+ }
+
+ /**
+ * Create a new auto-generated [[Alias]]. If the `name` is not provided, [[toPrettySql]] is going
+ * to be used to generate a proper alias name. We don't call
+ * [[ExpressionIdAssigner.mapExpression]] for outer aliases, because they should be manually
+ * mapped in the context of an outer query.
+ */
+ def newOuterAlias(
+ child: Expression,
+ name: Option[String] = None,
+ explicitMetadata: Option[Metadata] = None): Alias = {
+ newAliasImpl(
+ child = child,
+ name = name,
+ explicitMetadata = explicitMetadata,
+ skipExpressionIdAssigner = true
+ )
+ }
+
+ private def newAliasImpl(
+ child: Expression,
+ name: Option[String] = None,
+ explicitMetadata: Option[Metadata] = None,
+ skipExpressionIdAssigner: Boolean = false): Alias = {
+ val alias = Alias(
+ child = child,
+ name = name.getOrElse(toPrettySQL(child))
+ )(
+ explicitMetadata = explicitMetadata
+ )
+ if (skipExpressionIdAssigner) {
+ alias
+ } else {
+ expressionIdAssigner.mapExpression(alias)
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala
index d2b586f3d372d..0f90e8a852fea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala
@@ -100,11 +100,14 @@ class BinaryArithmeticResolver(
override def resolve(unresolvedBinaryArithmetic: BinaryArithmetic): Expression = {
val binaryArithmeticWithResolvedChildren: BinaryArithmetic =
- withResolvedChildren(unresolvedBinaryArithmetic, expressionResolver.resolve)
+ withResolvedChildren(unresolvedBinaryArithmetic, expressionResolver.resolve _)
+ .asInstanceOf[BinaryArithmetic]
+
val binaryArithmeticWithResolvedSubtree: Expression =
withResolvedSubtree(binaryArithmeticWithResolvedChildren, expressionResolver.resolve) {
transformBinaryArithmeticNode(binaryArithmeticWithResolvedChildren)
}
+
timezoneAwareExpressionResolver.withResolvedTimezone(
binaryArithmeticWithResolvedSubtree,
conf.sessionLocalTimeZone
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationId.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationId.scala
new file mode 100644
index 0000000000000..02c42853fd4f1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationId.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+
+/**
+ * The [[BridgedRelationId]] is a unique identifier for an unresolved relation in the whole logical
+ * plan including all the nested views. It is used to look up relations with resolved metadata
+ * which were processed by the fixed-point Analyzer when running two Analyzers in dual-run mode.
+ * Storing [[catalogAndNamespace]] is required to differentiate tables/views created in different
+ * catalogs, as their [[UnresolvedRelation]]s could have the same structure.
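+ *
+ * For example (illustrative names), two views created in different catalogs can both reference a
+ * table `t`:
+ *
+ * {{{
+ * CREATE VIEW catalog1.db.v AS SELECT * FROM t;
+ * CREATE VIEW catalog2.db.v AS SELECT * FROM t;
+ * }}}
+ *
+ * Both view bodies produce structurally identical [[UnresolvedRelation]]s for `t`, but they
+ * resolve to different tables depending on the catalog and namespace the view was created in,
+ * which is why [[catalogAndNamespace]] is part of the identifier.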
+ */
+case class BridgedRelationId(
+ unresolvedRelation: UnresolvedRelation,
+ catalogAndNamespace: Seq[String]
+)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationMetadataProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationMetadataProvider.scala
index a33675e9dfd09..db75a3909c458 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationMetadataProvider.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationMetadataProvider.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.catalyst.analysis.resolver
import org.apache.spark.sql.catalyst.analysis.RelationResolution
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.catalog.UnresolvedCatalogRelation
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.connector.catalog.CatalogManager
/**
@@ -31,15 +32,17 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
class BridgedRelationMetadataProvider(
override val catalogManager: CatalogManager,
override val relationResolution: RelationResolution,
- analyzerBridgeState: AnalyzerBridgeState
+ analyzerBridgeState: AnalyzerBridgeState,
+ viewResolver: ViewResolver
) extends RelationMetadataProvider {
override val relationsWithResolvedMetadata = new RelationsWithResolvedMetadata
- updateRelationsWithResolvedMetadata()
/**
* We update relations on each [[resolve]] call, because relation IDs might have changed.
* This can happen for the nested views, since catalog name may differ, and expanded table name
- * will differ for the same [[UnresolvedRelation]].
+ * will differ for the same [[UnresolvedRelation]]. To overcome this, we peek into the
+ * [[viewResolver]]'s most recent context and only resolve the relations which were created under
+ * that same context.
*
* See [[ViewResolver.resolve]] for more info on how SQL configs are propagated to nested views).
*/
@@ -49,12 +52,37 @@ class BridgedRelationMetadataProvider(
private def updateRelationsWithResolvedMetadata(): Unit = {
analyzerBridgeState.relationsWithResolvedMetadata.forEach(
- (unresolvedRelation, relationWithResolvedMetadata) => {
- relationsWithResolvedMetadata.put(
- relationIdFromUnresolvedRelation(unresolvedRelation),
- relationWithResolvedMetadata
- )
+ (bridgeRelationId, relationWithResolvedMetadata) => {
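+ // Only pick up relations that were bridged under the same catalog and namespace as the
+ // current view resolution context, so that relations from other nested-view contexts are
+ // not mixed in.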
+ if (viewResolver.getCatalogAndNamespace.getOrElse(Seq.empty)
+ == bridgeRelationId.catalogAndNamespace) {
+ relationsWithResolvedMetadata.put(
+ relationIdFromUnresolvedRelation(bridgeRelationId.unresolvedRelation),
+ tryConvertUnresolvedCatalogRelation(relationWithResolvedMetadata)
+ )
+ }
}
)
}
+
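+ /**
+ * If the bridge state recorded a resolved counterpart for this [[UnresolvedCatalogRelation]]
+ * (possibly wrapped in a [[SubqueryAlias]]), substitute it here. Otherwise return `source`
+ * unchanged.
+ */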
+ private def tryConvertUnresolvedCatalogRelation(source: LogicalPlan): LogicalPlan = {
+ source match {
+ case unresolvedCatalogRelation: UnresolvedCatalogRelation
+ if analyzerBridgeState.catalogRelationsWithResolvedMetadata
+ .containsKey(unresolvedCatalogRelation) =>
+ analyzerBridgeState.catalogRelationsWithResolvedMetadata.get(unresolvedCatalogRelation)
+
+ case SubqueryAlias(id, unresolvedCatalogRelation: UnresolvedCatalogRelation)
+ if analyzerBridgeState.catalogRelationsWithResolvedMetadata
+ .containsKey(unresolvedCatalogRelation) =>
+ SubqueryAlias(
+ id,
+ analyzerBridgeState.catalogRelationsWithResolvedMetadata.get(
+ unresolvedCatalogRelation
+ )
+ )
+
+ case _ =>
+ source
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala
index 4532965c6c684..428662d14dc85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala
@@ -40,7 +40,7 @@ class ConditionalExpressionResolver(
override def resolve(unresolvedConditionalExpression: ConditionalExpression): Expression = {
val conditionalExpressionWithResolvedChildren =
- withResolvedChildren(unresolvedConditionalExpression, expressionResolver.resolve)
+ withResolvedChildren(unresolvedConditionalExpression, expressionResolver.resolve _)
typeCoercionResolver.resolve(conditionalExpressionWithResolvedChildren)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CreateNamedStructResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CreateNamedStructResolver.scala
index 12c3c71b5e8be..d0e4ecea25cb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CreateNamedStructResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CreateNamedStructResolver.scala
@@ -31,6 +31,7 @@ class CreateNamedStructResolver(expressionResolver: ExpressionResolver)
override def resolve(createNamedStruct: CreateNamedStruct): Expression = {
val createNamedStructWithResolvedChildren =
withResolvedChildren(createNamedStruct, expressionResolver.resolve)
+ .asInstanceOf[CreateNamedStruct]
CreateNamedStructResolver.cleanupAliases(createNamedStructWithResolvedChildren)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CteScope.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CteScope.scala
index 4f61a25a3cd42..b0a39706ec225 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CteScope.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CteScope.scala
@@ -21,7 +21,13 @@ import java.util.{ArrayDeque, ArrayList}
import scala.jdk.CollectionConverters._
-import org.apache.spark.sql.catalyst.plans.logical.CTERelationDef
+import org.apache.spark.sql.catalyst.plans.logical.{
+ CTERelationDef,
+ LogicalPlan,
+ UnresolvedWith,
+ WithCTE
+}
+import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_WITH
/**
* The [[CteScope]] is responsible for keeping track of visible and known CTE definitions at a given
@@ -98,6 +104,67 @@ import org.apache.spark.sql.catalyst.plans.logical.CTERelationDef
* : +- ...
* }}}
*
+ * - The [[WithCTE]] operator is placed on top of the resolved operator if one of the following
+ * conditions is met:
+ * 1. We just resolved an [[UnresolvedWith]], which is the topmost [[UnresolvedWith]] of this
+ * root query, view or an expression subquery.
+ * 2. In case there is no single topmost [[UnresolvedWith]], we pick the least common ancestor
+ * of those branches. This is going to be a multi-child operator - [[Union]], [[Join]], etc.
+ *
+ * Here's an example for the second case:
+ *
+ * {{{
+ * SELECT * FROM (
+ * WITH cte AS (
+ * SELECT 1
+ * )
+ * SELECT * FROM cte
+ * UNION ALL
+ * (
+ * WITH cte AS (
+ * SELECT 2
+ * )
+ * SELECT * FROM cte
+ * )
+ * )
+ * }}}
+ *
+ * ->
+ *
+ * {{{
+ * Project [1#60]
+ * +- SubqueryAlias __auto_generated_subquery_name
+ * +- WithCTE
+ * :- CTERelationDef 30, false
+ * : +- ...
+ * :- CTERelationDef 31, false
+ * : +- ...
+ * +- Union false, false
+ * :- Project [1#60]
+ * : +- ...
+ * +- Project [2#61]
+ * +- ...
+ * }}}
+ *
+ * Consider a different example though:
+ *
+ * {{{
+ * SELECT * FROM (
+ * SELECT 1
+ * UNION ALL
+ * (
+ * WITH cte AS (
+ * SELECT 2
+ * )
+ * SELECT * FROM cte
+ * )
+ * )
+ * }}}
+ *
+ * The [[Union]] operator is not the least common ancestor of the [[UnresolvedWith]]s in the
+ * query. In fact, there's just a single [[UnresolvedWith]], which is exactly where the
+ * [[WithCTE]] needs to be placed.
+ *
* - However, if we have any expression subquery (scalar/IN/EXISTS...), the top
* [[CTERelationDef]]s and subquery's [[CTERelationDef]] won't be merged together (as they are
* separated by an expression tree):
@@ -167,6 +234,19 @@ class CteScope(val isRoot: Boolean, val isOpaque: Boolean) {
*/
private val visibleCtes = new IdentifierMap[CTERelationDef]
+ /**
+ * Optionally put [[WithCTE]] on top of the `resolvedOperator`. This is done only for root
+ * scopes and only if the `unresolvedOperator` is suitable for [[WithCTE]] placement. Return the
+ * `resolvedOperator` otherwise.
+ */
+ def tryPutWithCTE(unresolvedOperator: LogicalPlan, resolvedOperator: LogicalPlan): LogicalPlan = {
+ if (!knownCtes.isEmpty && isRoot && isSuitableOperatorForWithCTE(unresolvedOperator)) {
+ WithCTE(resolvedOperator, knownCtes.asScala.toSeq)
+ } else {
+ resolvedOperator
+ }
+ }
+
/**
* Register a new CTE definition in this scope. Since the scope is created per single WITH clause,
* there can be no name conflicts, but this is validated by the Parser in [[AstBuilder]]
@@ -197,11 +277,17 @@ class CteScope(val isRoot: Boolean, val isOpaque: Boolean) {
}
/**
- * Get all known (from this and child scopes) [[CTERelationDef]]s. This is used to construct
- * [[WithCTE]] from a root scope.
+ * This predicate returns `true` if the `unresolvedOperator` is suitable to place a [[WithCTE]]
+ * on top of its resolved counterpart. This is the case for:
+ * - [[UnresolvedWith]];
+ * - Multi-child operators with [[UnresolvedWith]]s in multiple subtrees.
*/
- def getKnownCtes: Seq[CTERelationDef] = {
- knownCtes.asScala.toSeq
+ private def isSuitableOperatorForWithCTE(unresolvedOperator: LogicalPlan): Boolean = {
+ unresolvedOperator match {
+ case _: UnresolvedWith => true
+ case _ =>
+ CteRegistry.isSuitableMultiChildOperatorForWithCTE(unresolvedOperator)
+ }
}
}
@@ -215,6 +301,29 @@ class CteRegistry {
def currentScope: CteScope = stack.peek()
+ /**
+ * This is a [[withNewScope]] variant specifically designed to be called when resolving the
+ * children of a multi-child operator (e.g. the children of a [[Join]] or [[Union]]).
+ *
+ * The `isRoot` flag has to be propagated from the parent scope if all of the following
+ * conditions are met:
+ * - The current scope is a root scope
+ * - The multi-child `unresolvedOperator` IS NOT suitable to place a [[WithCTE]]
+ * - Some operator in `unresolvedChild` subtree IS suitable to place a [[WithCTE]].
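+ *
+ * For example (repeating the illustration from the [[CteScope]] documentation):
+ *
+ * {{{
+ * SELECT 1
+ * UNION ALL
+ * (
+ * WITH cte AS (
+ * SELECT 2
+ * )
+ * SELECT * FROM cte
+ * )
+ * }}}
+ *
+ * Here the [[Union]] is not suitable for a [[WithCTE]] (only one of its subtrees contains an
+ * [[UnresolvedWith]]), so `isRoot` is propagated to the child that does contain it, and the
+ * [[WithCTE]] ends up on top of that child's resolved subtree.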
+ */
+ def withNewScopeUnderMultiChildOperator[R](
+ unresolvedOperator: LogicalPlan,
+ unresolvedChild: LogicalPlan
+ )(body: => R): R = {
+ withNewScope(
+ isRoot = currentScope.isRoot &&
+ !CteRegistry.isSuitableMultiChildOperatorForWithCTE(unresolvedOperator) &&
+ CteRegistry.hasSuitableOperatorForWithCTEInSubtree(unresolvedChild)
+ ) {
+ body
+ }
+ }
+
/**
* A RAII-wrapper for pushing/popping scopes. This is used by the [[Resolver]] to create a new
* scope for each WITH clause.
@@ -255,3 +364,19 @@ class CteRegistry {
result
}
}
+
+object CteRegistry {
+
+ /**
+ * This predicate returns `true` if the `unresolvedOperator` is a multi-child operator that
+ * contains multiple [[UnresolvedWith]] operators in its subtrees. This way we determine if
+ * this operator is suitable to place a [[WithCTE]] on top of its resolved counterpart.
+ */
+ def isSuitableMultiChildOperatorForWithCTE(unresolvedOperator: LogicalPlan): Boolean = {
+ unresolvedOperator.children.count(hasSuitableOperatorForWithCTEInSubtree(_)) > 1
+ }
+
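+ /**
+ * Returns `true` if the `unresolvedOperator` subtree contains an [[UnresolvedWith]] operator.
+ */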
+ def hasSuitableOperatorForWithCTEInSubtree(unresolvedOperator: LogicalPlan): Boolean = {
+ unresolvedOperator.containsPattern(UNRESOLVED_WITH)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DelegatesResolutionToExtensions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DelegatesResolutionToExtensions.scala
index e9de72e6fa36a..acac8edd4a73c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DelegatesResolutionToExtensions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DelegatesResolutionToExtensions.scala
@@ -51,7 +51,6 @@ trait DelegatesResolutionToExtensions {
matchedExtension match {
case None =>
resolutionResult = extension.resolveOperator(unresolvedOperator, resolver)
-
if (resolutionResult.isDefined) {
matchedExtension = Some(extension)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExplicitlyUnsupportedResolverFeature.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExplicitlyUnsupportedResolverFeature.scala
index c9b8dece77cec..475f38cabcfff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExplicitlyUnsupportedResolverFeature.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExplicitlyUnsupportedResolverFeature.scala
@@ -21,7 +21,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver
* This is an addon to [[ResolverGuard]] functionality for features that cannot be determined by
* only looking at the unresolved plan. [[Resolver]] will throw this control-flow exception
* when it encounters some explicitly unsupported feature. Later behavior depends on the value of
- * [[HybridAnalyzer.checkSupportedSinglePassFeatures]] flag:
+ * [[HybridAnalyzer.exposeExplicitlyUnsupportedResolverFeature]] flag:
* - If it is true: It will later be caught by [[HybridAnalyzer]] to abort single-pass
* analysis without comparing single-pass and fixed-point results. The motivation for this
* feature is the same as for the [[ResolverGuard]] - we want to have an explicit allowlist of
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala
index 2d8c332d8527b..cbd772d5a6781 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.{
ExprId,
NamedExpression
}
+import org.apache.spark.sql.catalyst.plans.logical.{CTERelationRef, LeafNode}
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.errors.QueryCompilationErrors
/**
@@ -35,11 +37,15 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
* that Optimizer performs its work correctly and does not produce correctness issues.
*
* The framework works the following way:
- * - Each leaf operator must have unique output IDs (even if it's the same table, view, or CTE).
+ * - Each leaf operator must have globally unique output IDs (even if it's the same table, view,
+ * or CTE).
* - The [[AttributeReference]]s get propagated "upwards" through the operator tree with their IDs
- * preserved.
- * - Each [[Alias]] gets assigned a new unique ID and it sticks with it after it gets converted to
- * an [[AttributeReference]] when it is outputted from the operator that produced it.
+ * preserved. In case of correlated subqueries [[AttributeReference]]s may propagate downwards
+ * from the outer scope to the point of correlated reference in the subquery. Currently only
+ * one level of correlation is supported.
+ * - Each [[Alias]] gets assigned a new globally unique ID and it sticks with it after it gets
+ * converted to an [[AttributeReference]] when it is outputted from the operator that produced
+ * it.
* - Any operator may have [[AttributeReference]]s with the same IDs in its output given it is the
* same attribute.
* Thus, **no multi-child operator may have children with conflicting [[AttributeReference]] IDs**.
@@ -116,9 +122,48 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
* Because the latter case will confuse the Optimizer and the top [[Project]] will be eliminated
* leading to incorrect result.
*
- * There's an important caveat here: the leftmost branch of a logical plan tree. In this branch we
- * need to preserve the expression IDs wherever possible because DataFrames may reference each other
- * using their attributes. This also makes sense for performance reasons.
+ * In case of partially resolved DataFrame subtrees with correlated subqueries inside we need to
+ * remap [[OuterReference]]s as well:
+ *
+ * {{{
+ * val df = spark.sql("""
+ * SELECT * FROM t1 WHERE EXISTS (
+ * SELECT * FROM t2 WHERE t2.id == t1.id
+ * )
+ * """)
+ * df.union(df)
+ * }}}
+ *
+ * The analyzed plan should be:
+ * {{{
+ * Union false, false
+ * :- Project [id#1]
+ * : +- Filter exists#9 [id#1]
+ * : : +- Project [id#16]
+ * : : +- Filter (id#16 = outer(id#1))
+ * : : +- SubqueryAlias spark_catalog.default.t2
+ * : : +- Relation spark_catalog.default.t2[id#16] parquet
+ * : +- SubqueryAlias spark_catalog.default.t1
+ * : +- Relation spark_catalog.default.t1[id#1] parquet
+ * +- Project [id#17 AS id#19]
+ * +- Project [id#17]
+ * +- Filter exists#9 [id#17]
+ * : +- Project [id#18]
+ * : +- Filter (id#18 = outer(id#17))
+ * : +- SubqueryAlias spark_catalog.default.t2
+ * : +- Relation spark_catalog.default.t2[id#18] parquet
+ * +- SubqueryAlias spark_catalog.default.t1
+ * +- Relation spark_catalog.default.t1[id#17] parquet
+ * }}}
+ *
+ * Note how id#17 is the same in the outer branch and in the subquery: it was properly remapped,
+ * because the right subtree of [[Union]] contained the same expression IDs as the left subtree.
+ * That's why we pass the main mapping as the outer mapping to the correlated subquery branch.
+ *
+ * There's an important caveat here: branches of a logical plan tree whose outputs do not
+ * conflict. We should preserve expression IDs on those branches wherever possible, because
+ * DataFrames may reference each other using their attributes. This also makes sense for
+ * performance reasons.
*
* Consider this example:
*
@@ -130,10 +175,27 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
*
* In this example `df("id")` references lower `id` attribute by expression ID, so `union` must not
* reassign expression IDs in `df1` (left child). Referencing `df2` (right child) is not supported
- * in Spark.
+ * in Spark, because [[Union]] does not output it, but we don't have to regenerate expression IDs
+ * in that branch either.
+ *
+ * However:
+ *
+ * {{{
+ * val df1 = spark.range(0, 10).select($"id")
+ * df1.union(df1).filter(df1("id") === 5)
+ * }}}
+ *
+ * Here we need to regenerate expression IDs in the right branch, because those would conflict
+ * (both branches are the same plan). Expression IDs in the left branch may be preserved.
+ *
+ * CTE references are handled in a special way to stay compatible with the fixed-point Analyzer.
+ * The first [[CTERelationRef]] that we meet in the query plan can preserve its output expression
+ * IDs, and the plan will be inlined by the [[InlineCTE]] without any artificial [[Alias]]es that
+ * "stitch" expression IDs together. This way we ensure that Optimizer behavior is the same as
+ * after the fixed-point Analyzer and that no extra projections are introduced.
*
* The [[ExpressionIdAssigner]] covers both SQL and DataFrame scenarios with single approach and is
- * integrated in the single-pass analysis framework.
+ * integrated into the single-pass analysis framework.
*
* The [[ExpressionIdAssigner]] is used in the following way:
* - When the [[Resolver]] traverses the tree downwards prior to starting bottom-up analysis,
@@ -144,88 +206,171 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
* the mapping needs to be initialized later with the correct output of a resolved operator.
* - When the bottom-up analysis starts, we assign IDs to all the [[NamedExpression]]s which are
* present in operators starting from the [[LeafNode]]s using [[mapExpression]].
- * [[createMapping]] is called right after each [[LeafNode]] is resolved, and first remapped
- * attributes come from that [[LeafNode]]. This is done in [[Resolver.handleLeafOperator]] for
- * each logical plan tree branch except the leftmost.
+ * [[createMappingForLeafOperator]] is called right after each [[LeafNode]] is resolved, and
+ * first remapped attributes come from that [[LeafNode]]. This is done if leaf operator output
+ * doesn't conflict with `globalExpressionIds`.
* - Once the child branch is resolved, [[withNewMapping]] ends by calling [[mappingStack.pop]].
- * - After the multi-child operator is resolved, we call [[createMapping]] to
- * initialize the mapping with attributes _chosen_ (e.g. [[Union.mergeChildOutputs]]) by that
- * operator's resolution algorithm and remap _old_ expression IDs to those chosen attributes.
+ * - After the multi-child operator is resolved, we call [[createMappingFromChildMappings]] to
+ * initialize the mapping with attributes collected in [[withNewMapping]] with
+ * `collectChildMapping = true`.
+ * - While traversing the expression tree, we may meet a [[SubqueryExpression]] and resolve its
+ * plan. In this case we call [[withNewMapping]] with `isSubqueryRoot = true` to pass the
+ * current mapping as outer mapping to the subquery branches. Any subquery branch may reference
+ * outer attributes, so if `isSubqueryRoot` is `false`, we pass the previous `outerMapping` to
+ * lower branches. Since we only support one level of correlation, for every subquery level
+ * current `mapping` becomes `outerMapping` for the next level.
* - Continue remapping expressions until we reach the root of the operator tree.
*/
class ExpressionIdAssigner {
- private val mappingStack = new ExpressionIdAssigner.Stack
- mappingStack.push(ExpressionIdAssigner.StackEntry(isLeftmostBranch = true))
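+ // All expression IDs seen so far (alias IDs and leaf operator output IDs). Used to detect
+ // conflicting IDs that have to be regenerated.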
+ private val globalExpressionIds = new HashSet[ExprId]
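+
+ // Output expression IDs of CTERelationRefs that were already registered. Only the first
+ // reference to a CTE keeps its output expression IDs; subsequent references get new ones.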
+ private val cteRelationRefOutputIds = new HashSet[ExprId]
- /**
- * Returns `true` if the current logical plan branch is the leftmost branch. This is important
- * in the context of preserving expression IDs in DataFrames. See class doc for more details.
- */
- def isLeftmostBranch: Boolean = mappingStack.peek().isLeftmostBranch
+ private val mappingStack = new ExpressionIdAssigner.Stack
+ mappingStack.push(ExpressionIdAssigner.StackEntry())
/**
* A RAII-wrapper for [[mappingStack.push]] and [[mappingStack.pop]]. [[Resolver]] uses this for
* every child of a multi-child operator to ensure that each operator branch uses an isolated
* expression ID mapping.
*
- * @param isLeftmostChild whether the current child is the leftmost child of the operator that is
- * being resolved. This is used to determine whether the new stack entry is gonna be in the
- * leftmost logical plan branch. It's `false` by default, because it's safer to remap attributes
- * than to leave duplicates (to prevent correctness issues).
+ * @param isSubqueryRoot whether the new branch is related to a subquery root. In this case we
+ * pass current `mapping` as `outerMapping` to the subquery branches. Otherwise we just
+ * propagate `outerMapping` itself, because any nested subquery operator may reference outer
+ * attributes.
+ * @param collectChildMapping whether to collect a child mapping into the current stack entry.
+ * This is used in multi-child operators to automatically propagate mapped expression IDs
+ * upwards using [[createMappingFromChildMappings]].
*/
- def withNewMapping[R](isLeftmostChild: Boolean = false)(body: => R): R = {
+ def withNewMapping[R](
+ isSubqueryRoot: Boolean = false,
+ collectChildMapping: Boolean = false
+ )(body: => R): R = {
+ val currentStackEntry = mappingStack.peek()
+
mappingStack.push(
ExpressionIdAssigner.StackEntry(
- isLeftmostBranch = isLeftmostChild && isLeftmostBranch
+ outerMapping = if (isSubqueryRoot) {
+ currentStackEntry.mapping.map(new ExpressionIdAssigner.Mapping(_))
+ } else {
+ currentStackEntry.outerMapping
+ }
)
)
+
try {
- body
+ val result = body
+
+ val childStackEntry = mappingStack.peek()
+ if (collectChildMapping) {
+ childStackEntry.mapping match {
+ case Some(childMapping) =>
+ currentStackEntry.childMappings.push(childMapping)
+ case None =>
+ throw SparkException.internalError("Child mapping doesn't exist")
+ }
+ }
+
+ result
} finally {
mappingStack.pop()
}
}
/**
- * Create mapping with the given `newOutput` that rewrites the `oldOutput`. This
- * is used by the [[Resolver]] after the multi-child operator is resolved to fill the current
- * mapping with the attributes _chosen_ by that operator's resolution algorithm and remap _old_
- * expression IDs to those chosen attributes. It's also used by the [[ExpressionResolver]] right
- * before remapping the attributes of a [[LeafNode]].
+ * Create mapping for the given `newOperator` that replaces the `oldOperator`. This is used by
+ * the [[Resolver]] after a certain [[LeafNode]] is resolved to make sure that leaf node outputs
+ * in the query don't have conflicting expression IDs.
*
- * `oldOutput` is present for already resolved subtrees (e.g. DataFrames), but for SQL queries
+ * `oldOperator` is present for already resolved subtrees (e.g. DataFrames), but for SQL queries
* is will be `None`, because that logical plan is analyzed for the first time.
*/
- def createMapping(
- newOutput: Seq[Attribute] = Seq.empty,
- oldOutput: Option[Seq[Attribute]] = None): Unit = {
+ def createMappingForLeafOperator(
+ newOperator: LeafNode,
+ oldOperator: Option[LeafNode] = None): Unit = {
if (mappingStack.peek().mapping.isDefined) {
throw SparkException.internalError(
- s"Attempt to overwrite existing mapping. New output: $newOutput, old output: $oldOutput"
+ "Attempt to overwrite existing mapping. " +
+ s"New operator: $newOperator, old operator: $oldOperator"
)
}
val newMapping = new ExpressionIdAssigner.Mapping
- oldOutput match {
- case Some(oldOutput) =>
- if (newOutput.length != oldOutput.length) {
+ oldOperator match {
+ case Some(oldOperator) =>
+ if (newOperator.output.length != oldOperator.output.length) {
throw SparkException.internalError(
- s"Outputs have different lengths. New output: $newOutput, old output: $oldOutput"
+ "Outputs have different lengths. " +
+ s"New operator: $newOperator, old operator: $oldOperator"
)
}
- newOutput.zip(oldOutput).foreach {
+ newOperator.output.zip(oldOperator.output).foreach {
case (newAttribute, oldAttribute) =>
newMapping.put(oldAttribute.exprId, newAttribute.exprId)
newMapping.put(newAttribute.exprId, newAttribute.exprId)
+
+ registerLeafOperatorAttribute(newOperator, newAttribute)
}
case None =>
- newOutput.foreach { newAttribute =>
+ newOperator.output.foreach { newAttribute =>
newMapping.put(newAttribute.exprId, newAttribute.exprId)
+
+ registerLeafOperatorAttribute(newOperator, newAttribute)
}
}
- mappingStack.push(mappingStack.pop().copy(mapping = Some(newMapping)))
+ setCurrentMapping(newMapping)
+ }
+
+ /**
+ * Create new mapping in current scope based on collected child mappings. The calling code
+ * must pass `collectChildMapping = true` to all the [[withNewMapping]] calls beforehand.
+ *
+ * Since nodes are resolved from left to right (the Analyzer is guaranteed to resolve left
+ * branches first), calling [[childMappings.pop]] returns the mappings from right to left, so
+ * duplicate expression IDs coming from the right mappings get overwritten by the left ones. This
+ * order is very important, because in case of duplicate DataFrame subtrees like self-joins,
+ * expression IDs from the right duplicate branch cannot be accessed:
+ *
+ * {{{
+ * val df1 = spark.range(0, 10)
+ * val df2 = df1.select(($"id" + 1).as("id"))
+ *
+ * // Both branches originate from the same `df1`, and have duplicate IDs, so right branch IDs
+ * // are regenerated. Thus, it's important to prioritize left mapping values for the same keys.
+ * val df3 = df2.join(df1, "id")
+ *
+ * // This succeeds because left branch IDs are preserved.
+ * df3.where(df2("id") === 1)
+ *
+ * // This fails because right branch IDs are regenerated.
+ * df3.where(df1("id") === 1)
+ * }}}
+ *
+ * This is used by multi child operators like [[Join]] or [[Union]] to propagate mapped
+ * expression IDs upwards.
+ */
+ def createMappingFromChildMappings(): Unit = {
+ if (mappingStack.peek().mapping.isDefined) {
+ throw SparkException.internalError(
+ "Attempt to overwrite existing mapping with child mappings"
+ )
+ }
+
+ val currentStackEntry = mappingStack.peek()
+ if (currentStackEntry.childMappings.isEmpty) {
+ throw SparkException.internalError("No child mappings to create new current mapping")
+ }
+
+ val newMapping = new ExpressionIdAssigner.Mapping
+ while (!currentStackEntry.childMappings.isEmpty) {
+ val nextMapping = currentStackEntry.childMappings.pop()
+
+ newMapping.putAll(nextMapping)
+ }
+
+ setCurrentMapping(newMapping)
}
/**
@@ -233,9 +378,9 @@ class ExpressionIdAssigner {
* expression, or return a corresponding new instance of the same attribute, that was previously
* reallocated and is present in the current [[mappingStack]] entry.
*
- * For [[Alias]]es: Try to preserve them if we are in the leftmost logical plan tree branch and
- * unless they conflict. Conflicting [[Alias]] IDs are never acceptable. Otherwise, reallocate
- * with a new ID and return that instance.
+ * For [[Alias]]es: Try to preserve the original [[Alias]] if its ID doesn't conflict with
+ * `globalExpressionIds`. Conflicting [[Alias]] IDs are never acceptable.
+ * Otherwise, reallocate with a new ID and return that instance.
*
* For [[AttributeReference]]s: If the attribute is present in the current [[mappingStack]] entry,
* return that instance, otherwise reallocate with a new ID and return that instance. The mapping
@@ -267,32 +412,39 @@ class ExpressionIdAssigner {
* spark.sql("SELECT col1 FROM VALUES (1)").select(col("col1").as("a", metadata1)).to(schema)
* }}}
*/
- def mapExpression(originalExpression: NamedExpression): NamedExpression = {
+ def mapExpression[NamedExpressionType <: NamedExpression](
+ originalExpression: NamedExpressionType): NamedExpressionType = {
if (mappingStack.peek().mapping.isEmpty) {
throw SparkException.internalError(
- "Expression ID mapping doesn't exist. Please call createMapping(...) first. " +
- s"Original expression: $originalExpression"
+ "Expression ID mapping doesn't exist. Please first call " +
+ "createMappingForLeafOperator(...) for leaf nodes or createMappingFromChildMappings(...) " +
+ s"for multi-child nodes. Original expression: $originalExpression"
)
}
val currentMapping = mappingStack.peek().mapping.get
val resultExpression = originalExpression match {
- case alias: Alias if isLeftmostBranch =>
- val resultAlias = currentMapping.get(alias.exprId) match {
- case null =>
- alias
- case _ =>
- alias.newInstance()
+ case alias: Alias =>
+ val resultAlias = if (globalExpressionIds.contains(alias.exprId)) {
+ val newAlias = newAliasInstance(alias)
+ currentMapping.put(alias.exprId, newAlias.exprId)
+ newAlias
+ } else {
+ alias
}
+
currentMapping.put(resultAlias.exprId, resultAlias.exprId)
+
+ globalExpressionIds.add(resultAlias.exprId)
+
resultAlias
- case alias: Alias =>
- reassignExpressionId(alias, currentMapping)
case attributeReference: AttributeReference =>
currentMapping.get(attributeReference.exprId) match {
case null =>
- reassignExpressionId(attributeReference, currentMapping)
+ throw SparkException.internalError(
+ s"Encountered a dangling attribute reference $attributeReference"
+ )
case mappedExpressionId =>
attributeReference.withExprId(mappedExpressionId)
}
@@ -302,31 +454,96 @@ class ExpressionIdAssigner {
)
}
- resultExpression.copyTagsFrom(originalExpression)
- resultExpression
+ resultExpression.asInstanceOf[NamedExpressionType]
+ }
+
+ /**
+ * Map [[AttributeReference]] which is a child of [[OuterReference]]. When [[ExpressionResolver]]
+ * meets an attribute under a resolved [[OuterReference]], it remaps it using the outer
+ * mapping passed from the parent plan of the [[SubqueryExpression]] that is currently being
+ * re-analyzed. This mapping must exist, as well as a mapped expression ID. Otherwise we have met
+ * a dangling outer reference, which is an internal error.
+ */
+ def mapOuterReference(attributeReference: AttributeReference): AttributeReference = {
+ if (mappingStack.peek().outerMapping.isEmpty) {
+ throw SparkException.internalError(
+ "Outer expression ID mapping doesn't exist while remapping outer reference " +
+ s"$attributeReference"
+ )
+ }
+
+ mappingStack.peek().outerMapping.get.get(attributeReference.exprId) match {
+ case null =>
+ throw SparkException.internalError(
+ s"No mapped expression ID for outer reference $attributeReference"
+ )
+ case mappedExpressionId =>
+ attributeReference.withExprId(mappedExpressionId)
+ }
+ }
+
+ /**
+ * Returns `true` if expression IDs for the current [[LeafNode]] should be preserved. This is
+ * important for DataFrames that reference columns by their IDs. See class doc for more details.
+ *
+ * Expression IDs of outputs of the first CTE reference are not regenerated for compatibility
+ * with the fixed-point Analyzer.
+ */
+ def shouldPreserveLeafOperatorIds(leafOperator: LeafNode): Boolean = {
+ leafOperator match {
+ case cteRelationRef: CTERelationRef =>
+ cteRelationRef.output.forall { attribute =>
+ !cteRelationRefOutputIds.contains(attribute.exprId)
+ }
+ case _ =>
+ leafOperator.output.forall { attribute =>
+ !globalExpressionIds.contains(attribute.exprId)
+ }
+ }
}
- private def reassignExpressionId(
- originalExpression: NamedExpression,
- currentMapping: ExpressionIdAssigner.Mapping): NamedExpression = {
- val newExpression = originalExpression.newInstance()
+ private def setCurrentMapping(mapping: ExpressionIdAssigner.Mapping): Unit = {
+ val currentEntry = mappingStack.pop()
+ mappingStack.push(currentEntry.copy(mapping = Some(mapping)))
+ }
- currentMapping.put(originalExpression.exprId, newExpression.exprId)
- currentMapping.put(newExpression.exprId, newExpression.exprId)
+ private def newAliasInstance(alias: Alias): Alias = {
+ val newAlias = withOrigin(alias.origin) {
+ alias.newInstance().asInstanceOf[Alias]
+ }
+ newAlias.copyTagsFrom(alias)
+ newAlias
+ }
- newExpression
+ private def registerLeafOperatorAttribute(leafOperator: LeafNode, attribute: Attribute): Unit = {
+ globalExpressionIds.add(attribute.exprId)
+ if (leafOperator.isInstanceOf[CTERelationRef]) {
+ cteRelationRefOutputIds.add(attribute.exprId)
+ }
}
}
object ExpressionIdAssigner {
type Mapping = HashMap[ExprId, ExprId]
- case class StackEntry(mapping: Option[Mapping] = None, isLeftmostBranch: Boolean = false)
+ case class StackEntry(
+ mapping: Option[Mapping] = None,
+ outerMapping: Option[Mapping] = None,
+ childMappings: ArrayDeque[Mapping] = new ArrayDeque[Mapping])
type Stack = ArrayDeque[StackEntry]
/**
- * Assert that `outputs` don't have conflicting expression IDs. This is only relevant for child
+ * Assert that `outputs` don't have conflicting expression IDs.
+ */
+ def assertOutputsHaveNoConflictingExpressionIds(outputs: Seq[Seq[Attribute]]): Unit = {
+ if (doOutputsHaveConflictingExpressionIds(outputs)) {
+ throw SparkException.internalError(s"Conflicting expression IDs in child outputs: $outputs")
+ }
+ }
+
+ /**
+ * Check whether `outputs` have conflicting expression IDs. This is only relevant for child
* outputs of multi-child operators. Conflicting attributes are only checked between different
* child branches, since one branch may output the same attribute multiple times. Hence, we use
* only distinct expression IDs from each output.
@@ -344,14 +561,11 @@ object ExpressionIdAssigner {
* SELECT col1 FROM t1
* ;
* }}}
+ *
+ * One edge case is [[WithCTE]] - we don't have to check conflicts between [[CTERelationDef]]s and
+ * the plan itself.
*/
- def assertOutputsHaveNoConflictingExpressionIds(outputs: Seq[Seq[Attribute]]): Unit = {
- if (doOutputsHaveConflictingExpressionIds(outputs)) {
- throw SparkException.internalError(s"Conflicting expression IDs in child outputs: $outputs")
- }
- }
-
- private def doOutputsHaveConflictingExpressionIds(outputs: Seq[Seq[Attribute]]): Boolean = {
+ def doOutputsHaveConflictingExpressionIds(outputs: Seq[Seq[Attribute]]): Boolean = {
outputs.length > 1 && {
val expressionIds = new HashSet[ExprId]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala
index 822cda2289621..66fa5a4226e63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala
@@ -23,21 +23,66 @@ package org.apache.spark.sql.catalyst.analysis.resolver
* of [[ExpressionResolver.resolve]] call, which are not the resolved child itself, from children
* to parents.
*
- * @param hasAggregateExpressionsInASubtree A flag that highlights that a specific node
- * corresponding to [[ExpressionResolutionContext]] has
- * aggregate expressions in its subtree.
- * @param hasAttributeInASubtree A flag that highlights that a specific node corresponding to
- * [[ExpressionResolutionContext]] has attributes in its subtree.
+ * @param isRoot A flag indicating that we are resolving the root of the expression tree. It's
+ * used by the [[ExpressionResolver]] to correctly propagate top-level information like
+ * [[isTopOfProjectList]]. It's going to be set to `true` for the top-level context when we are
+ * entering expression resolution from a specific operator (this can be either a top-level query
+ * or a subquery).
+ * @param hasLocalReferences A flag that highlights that a specific node corresponding to
+ * [[ExpressionResolutionContext]] has local [[AttributeReference]]s in its subtree.
+ * @param hasOuterReferences A flag that highlights that a specific node corresponding to
+ * [[ExpressionResolutionContext]] has outer [[AttributeReference]]s in its subtree.
+ * @param hasAggregateExpressions A flag that highlights that a specific node
+ * corresponding to [[ExpressionResolutionContext]] has aggregate expressions in its subtree.
+ * @param hasAttributeOutsideOfAggregateExpressions A flag that highlights that a specific node
+ * corresponding to [[ExpressionResolutionContext]] has attributes in its subtree which are not
+ * under an [[AggregateExpression]].
* @param hasLateralColumnAlias A flag that highlights that a specific node corresponding to
- * [[ExpressionResolutionContext]] has LCA in its subtree.
+ * [[ExpressionResolutionContext]] has LCA in its subtree.
+ * @param isTopOfProjectList A flag indicating that we are resolving top of [[Project]] list.
+ * Otherwise, extra [[Alias]]es have to be stripped away.
+ * @param resolvingGroupingExpressions A flag indicating whether an expression we are resolving is
+ * one of [[Aggregate.groupingExpressions]].
*/
class ExpressionResolutionContext(
- var hasAggregateExpressionsInASubtree: Boolean = false,
- var hasAttributeInASubtree: Boolean = false,
- var hasLateralColumnAlias: Boolean = false) {
- def merge(other: ExpressionResolutionContext): Unit = {
- hasAggregateExpressionsInASubtree |= other.hasAggregateExpressionsInASubtree
- hasAttributeInASubtree |= other.hasAttributeInASubtree
- hasLateralColumnAlias |= other.hasLateralColumnAlias
+ val isRoot: Boolean = false,
+ var hasLocalReferences: Boolean = false,
+ var hasOuterReferences: Boolean = false,
+ var hasAggregateExpressions: Boolean = false,
+ var hasAttributeOutsideOfAggregateExpressions: Boolean = false,
+ var hasLateralColumnAlias: Boolean = false,
+ var isTopOfProjectList: Boolean = false,
+ var resolvingGroupingExpressions: Boolean = false) {
+
+ /**
+ * Propagate generic information that is valid across the whole expression tree from the
+ * [[child]] context.
+ */
+ def mergeChild(child: ExpressionResolutionContext): Unit = {
+ hasLocalReferences |= child.hasLocalReferences
+ hasOuterReferences |= child.hasOuterReferences
+ hasAggregateExpressions |= child.hasAggregateExpressions
+ hasAttributeOutsideOfAggregateExpressions |= child.hasAttributeOutsideOfAggregateExpressions
+ hasLateralColumnAlias |= child.hasLateralColumnAlias
+ }
+}
+
+object ExpressionResolutionContext {
+
+ /**
+ * Create a new child [[ExpressionResolutionContext]]. Propagates the relevant information
+ * top-down.
+ */
+ def createChild(parent: ExpressionResolutionContext): ExpressionResolutionContext = {
+ if (parent.isRoot) {
+ new ExpressionResolutionContext(
+ isTopOfProjectList = parent.isTopOfProjectList,
+ resolvingGroupingExpressions = parent.resolvingGroupingExpressions
+ )
+ } else {
+ new ExpressionResolutionContext(
+ resolvingGroupingExpressions = parent.resolvingGroupingExpressions
+ )
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala
index 3ca62348e892e..e0508e924678a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala
@@ -22,10 +22,15 @@ import org.apache.spark.sql.catalyst.expressions.{
ArraysZip,
AttributeReference,
BinaryExpression,
+ Exists,
Expression,
+ InSubquery,
+ ListQuery,
Literal,
NamedExpression,
+ OuterReference,
Predicate,
+ ScalarSubquery,
TimeZoneAwareExpression
}
import org.apache.spark.sql.types.BooleanType
@@ -36,6 +41,12 @@ import org.apache.spark.sql.types.BooleanType
* logical plan. You can find more info in the [[ResolutionValidator]] scaladoc.
*/
class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) {
+ private val attributeScopeStack = resolutionValidator.getAttributeScopeStack
+
+ /**
+ * The flag to indicate if the validation process traverses the outer reference subtree.
+ */
+ private var inOuterReferenceSubtree = false
/**
* Validate resolved expression tree. The principle is the same as
@@ -51,6 +62,14 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) {
validateBinaryExpression(binaryExpression)
case literal: Literal =>
validateLiteral(literal)
+ case scalarSubquery: ScalarSubquery =>
+ validateScalarSubquery(scalarSubquery)
+ case inSubquery: InSubquery =>
+ validateInSubquery(inSubquery)
+ case listQuery: ListQuery =>
+ validateListQuery(listQuery)
+ case exists: Exists =>
+ validateExists(exists)
case predicate: Predicate =>
validatePredicate(predicate)
case arraysZip: ArraysZip =>
@@ -68,6 +87,8 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) {
validateAttributeReference(attributeReference)
case alias: Alias =>
validateAlias(alias)
+ case outerReference: OuterReference =>
+ validateOuterReference(outerReference)
}
}
@@ -82,9 +103,8 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) {
private def validateAttributeReference(attributeReference: AttributeReference): Unit = {
assert(
- resolutionValidator.attributeScopeStack.top.contains(attributeReference),
- s"Attribute $attributeReference is missing from attribute scope: " +
- s"${resolutionValidator.attributeScopeStack.top}"
+ attributeScopeStack.contains(attributeReference, isOuterReference = inOuterReferenceSubtree),
+ s"Attribute $attributeReference is missing from attribute scope stack: $attributeScopeStack"
)
}
@@ -92,6 +112,15 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) {
validate(alias.child)
}
+ private def validateOuterReference(outerReference: OuterReference): Unit = {
+ inOuterReferenceSubtree = true
+ try {
+ validate(outerReference.e)
+ } finally {
+ inOuterReferenceSubtree = false
+ }
+ }
+
private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = {
validate(binaryExpression.left)
validate(binaryExpression.right)
@@ -106,6 +135,49 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) {
private def validateLiteral(literal: Literal): Unit = {}
+ private def validateScalarSubquery(scalarSubquery: ScalarSubquery): Unit = {
+ attributeScopeStack.withNewScope(isSubqueryRoot = true) {
+ resolutionValidator.validate(scalarSubquery.plan)
+ }
+
+ for (outerAttribute <- scalarSubquery.outerAttrs) {
+ validate(outerAttribute)
+ }
+
+ assert(
+ scalarSubquery.plan.output.size == 1,
+ s"Scalar subquery returns more than one column: ${scalarSubquery.plan.output}"
+ )
+ }
+
+ private def validateInSubquery(inSubquery: InSubquery): Unit = {
+ for (value <- inSubquery.values) {
+ validate(value)
+ }
+
+ validate(inSubquery.query)
+ }
+
+ private def validateListQuery(listQuery: ListQuery): Unit = {
+ attributeScopeStack.withNewScope(isSubqueryRoot = true) {
+ resolutionValidator.validate(listQuery.plan)
+ }
+
+ for (outerAttribute <- listQuery.outerAttrs) {
+ validate(outerAttribute)
+ }
+ }
+
+ private def validateExists(exists: Exists): Unit = {
+ attributeScopeStack.withNewScope(isSubqueryRoot = true) {
+ resolutionValidator.validate(exists.plan)
+ }
+
+ for (outerAttribute <- exists.outerAttrs) {
+ validate(outerAttribute)
+ }
+ }
+
private def validateArraysZip(arraysZip: ArraysZip): Unit = {
arraysZip.children.foreach(validate)
arraysZip.names.foreach(validate)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala
index d6fbdad46fb1b..2202d90a04a77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala
@@ -17,44 +17,30 @@
package org.apache.spark.sql.catalyst.analysis.resolver
-import java.util.ArrayDeque
+import java.util.{ArrayDeque, HashMap, HashSet}
+
+import scala.annotation.nowarn
+import scala.collection.mutable
+import scala.jdk.CollectionConverters._
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.{
withPosition,
FunctionResolution,
GetViewColumnByNameAndOrdinal,
+ TypeCoercionValidation,
UnresolvedAlias,
UnresolvedAttribute,
UnresolvedFunction,
UnresolvedStar,
UpCastResolution
}
-import org.apache.spark.sql.catalyst.expressions.{
- Alias,
- AttributeReference,
- BinaryArithmetic,
- ConditionalExpression,
- CreateNamedStruct,
- DateAddYMInterval,
- Expression,
- ExtractIntervalPart,
- GetTimeField,
- Literal,
- MakeTimestamp,
- NamedExpression,
- Predicate,
- RuntimeReplaceable,
- TimeAdd,
- TimeZoneAwareExpression,
- UnaryMinus,
- UnresolvedNamedLambdaVariable,
- UpCast
-}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.MetadataBuilder
@@ -85,13 +71,36 @@ class ExpressionResolver(
extends TreeNodeResolver[Expression, Expression]
with ProducesUnresolvedSubtree
with ResolvesExpressionChildren {
- private val isLcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)
/**
- * This is a flag indicating that we are resolving top of [[Project]] list. Otherwise extra
- * [[Alias]]es have to be stripped away.
+ * This field stores referenced attributes from the most recently resolved expression tree. It is
+ * populated in [[resolveExpressionTreeInOperatorImpl]] when [[ExpressionTreeTraversal]] is
+ * popped from the stack. It is the responsibility of the parent operator resolver to collect
+ * referenced attributes, before this field is overwritten.
+ */
+ private var lastReferencedAttributes: Option[HashMap[ExprId, Attribute]] = None
+
+ /**
+ * This field stores invalid expressions in the context of the parent operator from the most
+ * recently resolved expression tree. It is populated in [[resolveExpressionTreeInOperatorImpl]]
+ * when [[ExpressionTreeTraversal]] is popped from the stack. It is the responsibility of the
+ * parent operator resolver to collect invalid expressions, before this field is overwritten.
+ */
+ private var lastInvalidExpressionsInTheContextOfOperator: Option[Seq[Expression]] = None
+
+ /**
+ * This field contains the aliases of [[AggregateExpression]]s that were extracted during the
+ * most recent expression tree resolution. It is populated in
+ * [[resolveExpressionTreeInOperatorImpl]] when [[ExpressionTreeTraversal]] is popped from the
+ * stack.
*/
- private var isTopOfProjectList: Boolean = false
+ private var lastExtractedAggregateExpressionAliases: Option[Seq[Alias]] = None
+
+ /**
+ * This is a flag indicating that we are re-analyzing a resolved [[OuterReference]] subtree. It's
+ * managed by [[handleResolvedOuterReference]].
+ */
+ private var inOuterReferenceSubtree: Boolean = false
/**
* The stack of parent operators which were encountered during the resolution of a certain
@@ -117,10 +126,11 @@ class ExpressionResolver(
*
* Project -> Project
*/
- private val parentOperators = new ArrayDeque[LogicalPlan]
private val expressionIdAssigner = new ExpressionIdAssigner
+ private val traversals = new ExpressionTreeTraversalStack
private val expressionResolutionContextStack = new ArrayDeque[ExpressionResolutionContext]
private val scopes = resolver.getNameScopes
+ private val subqueryRegistry = resolver.getSubqueryRegistry
private val aliasResolver = new AliasResolver(this)
private val timezoneAwareExpressionResolver = new TimezoneAwareExpressionResolver(this)
@@ -132,10 +142,10 @@ class ExpressionResolver(
this,
timezoneAwareExpressionResolver
)
- private val limitExpressionResolver = new LimitExpressionResolver
+ private val limitLikeExpressionValidator = new LimitLikeExpressionValidator
private val typeCoercionResolver = new TypeCoercionResolver(timezoneAwareExpressionResolver)
private val aggregateExpressionResolver =
- new AggregateExpressionResolver(this, timezoneAwareExpressionResolver)
+ new AggregateExpressionResolver(resolver, this, timezoneAwareExpressionResolver)
private val functionResolver = new FunctionResolver(
this,
timezoneAwareExpressionResolver,
@@ -145,6 +155,51 @@ class ExpressionResolver(
)
private val timeAddResolver = new TimeAddResolver(this, timezoneAwareExpressionResolver)
private val unaryMinusResolver = new UnaryMinusResolver(this, timezoneAwareExpressionResolver)
+ private val subqueryExpressionResolver = new SubqueryExpressionResolver(this, resolver)
+
+ /**
+ * Get the expression tree traversal stack.
+ */
+ def getExpressionTreeTraversals: ExpressionTreeTraversalStack = traversals
+
+ /**
+ * Get the expression resolution context stack.
+ */
+ def getExpressionResolutionContextStack: ArrayDeque[ExpressionResolutionContext] =
+ expressionResolutionContextStack
+
+ def getExpressionIdAssigner: ExpressionIdAssigner = expressionIdAssigner
+
+ /**
+ * Get [[NameScopeStack]] bound to the used [[Resolver]].
+ */
+ def getNameScopes: NameScopeStack = scopes
+
+ /**
+ * Get the [[TypeCoercionResolver]] which contains all the transformations for generic coercion.
+ */
+ def getGenericTypeCoercionResolver: TypeCoercionResolver = typeCoercionResolver
+
+ /**
+ * Returns all attributes that have been referenced during the most recent expression tree
+ * resolution.
+ */
+ def getLastReferencedAttributes: HashMap[ExprId, Attribute] =
+ lastReferencedAttributes.getOrElse(new HashMap[ExprId, Attribute])
+
+ /**
+ * Returns all invalid expressions in the context of the parent operator from the most recent
+ * expression tree resolution.
+ */
+ def getLastInvalidExpressionsInTheContextOfOperator: Seq[Expression] =
+ lastInvalidExpressionsInTheContextOfOperator.getOrElse(Seq.empty)
+
+ /**
+ * Returns all aliases of [[AggregateExpression]]s that were extracted during the most recent
+ * expression tree resolution.
+ */
+ def getLastExtractedAggregateExpressionAliases: Seq[Alias] =
+ lastExtractedAggregateExpressionAliases.getOrElse(Seq.empty)
/**
* Resolve `unresolvedExpression` which is a child of `parentOperator`. This is the main entry
@@ -206,7 +261,11 @@ class ExpressionResolver(
case unresolvedLiteral: Literal =>
resolveLiteral(unresolvedLiteral)
case unresolvedPredicate: Predicate =>
- predicateResolver.resolve(unresolvedPredicate)
+ resolvePredicate(unresolvedPredicate)
+ case unresolvedScalarSubquery: ScalarSubquery =>
+ subqueryExpressionResolver.resolveScalarSubquery(unresolvedScalarSubquery)
+ case unresolvedListQuery: ListQuery =>
+ subqueryExpressionResolver.resolveListQuery(unresolvedListQuery)
case unresolvedTimeAdd: TimeAdd =>
timeAddResolver.resolve(unresolvedTimeAdd)
case unresolvedUnaryMinus: UnaryMinus =>
@@ -227,6 +286,8 @@ class ExpressionResolver(
timezoneAwareExpressionResolver.resolve(unresolvedTimezoneExpression)
case unresolvedUpCast: UpCast =>
resolveUpCast(unresolvedUpCast)
+ case unresolvedCollation: UnresolvedCollation =>
+ resolveCollation(unresolvedCollation)
case expression: Expression =>
resolveExpressionGenericallyWithTypeCoercion(expression)
}
@@ -234,8 +295,8 @@ class ExpressionResolver(
preserveTags(unresolvedExpression, resolvedExpression)
popResolutionContext()
- if (!resolvedExpression.resolved) {
- throwSinglePassFailedToResolveExpression(resolvedExpression)
+ withPosition(unresolvedExpression) {
+ validateResolvedExpressionGenerically(resolvedExpression)
}
planLogger.logExpressionTreeResolution(unresolvedExpression, resolvedExpression)
@@ -244,41 +305,20 @@ class ExpressionResolver(
}
/**
- * Get the expression resolution context stack.
- */
- def getExpressionResolutionContextStack: ArrayDeque[ExpressionResolutionContext] = {
- expressionResolutionContextStack
- }
-
- def getExpressionIdAssigner: ExpressionIdAssigner = expressionIdAssigner
-
- /**
- * Get the most recent operator (bottommost) from the `parentOperators` stack.
+ * Resolve and validate the limit-like expression from a [[LocalLimit]], [[GlobalLimit]],
+ * [[Offset]] or [[Tail]] operator.
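+ *
+ * For example (illustrative), in
+ * {{{ SELECT * FROM VALUES (1) LIMIT -1 }}}
+ * the limit expression is resolved here and then rejected by the
+ * [[LimitLikeExpressionValidator]], since limit-like expressions must evaluate to a non-negative
+ * integer.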
*/
- def getParentOperator: Option[LogicalPlan] = {
- if (parentOperators.size() > 0) {
- Some(parentOperators.peek())
- } else {
- None
- }
- }
-
- /**
- * Get [[NameScopeStack]] bound to the used [[Resolver]].
- */
- def getNameScopes: NameScopeStack = scopes
-
- /**
- * Resolve the limit expression from either a [[LocalLimit]] or a [[GlobalLimit]] operator.
- */
- def resolveLimitExpression(
- unresolvedLimitExpr: Expression,
- unresolvedLimit: LogicalPlan): Expression = {
- val resolvedLimitExpr = resolveExpressionTreeInOperator(
- unresolvedLimitExpr,
- unresolvedLimit
+ def resolveLimitLikeExpression(
+ unresolvedLimitLikeExpr: Expression,
+ partiallyResolvedLimitLike: LogicalPlan): Expression = {
+ val resolvedLimitLikeExpr = resolveExpressionTreeInOperator(
+ unresolvedLimitLikeExpr,
+ partiallyResolvedLimitLike
+ )
+ limitLikeExpressionValidator.validateLimitLikeExpr(
+ resolvedLimitLikeExpr,
+ partiallyResolvedLimitLike
)
- limitExpressionResolver.resolve(resolvedLimitExpr)
}
/**
@@ -303,52 +343,220 @@ class ExpressionResolver(
* ResolvedProjectList(
* expressions = [count(col1) as count(col1), 2 AS 2],
* hasAggregateExpressions = true, // because it contains `count(col1)` in the project list
- * hasAttributes = false // because it doesn't contain any [[AttributeReference]]s in the
- * // project list (only under the aggregate expression, please check
- * // [[AggregateExpressionResolver]] for more details).
+ * )
*/
def resolveProjectList(
- unresolvedProjectList: Seq[NamedExpression],
+ sourceUnresolvedProjectList: Seq[NamedExpression],
operator: LogicalPlan): ResolvedProjectList = {
- val projectListResolutionContext = new ExpressionResolutionContext
- val resolvedProjectList = unresolvedProjectList.flatMap {
+ val unresolvedProjectList = tryDrainLazySequences(sourceUnresolvedProjectList)
+
+ var hasAggregateExpressions = false
+ var hasLateralColumnAlias = false
+
+ val unresolvedProjectListWithStarsExpanded = unresolvedProjectList.flatMap {
case unresolvedStar: UnresolvedStar =>
resolveStar(unresolvedStar)
- case other =>
- val (resolvedElement, resolvedElementContext) =
- resolveExpressionTreeInOperatorImpl(other, operator)
- projectListResolutionContext.merge(resolvedElementContext)
- Seq(resolvedElement.asInstanceOf[NamedExpression])
+ case other => Seq(other)
+ }
+
+ val resolvedProjectList = unresolvedProjectListWithStarsExpanded.flatMap { expression =>
+ val (resolvedElement, resolvedElementContext) = {
+ resolveExpressionTreeInOperatorImpl(
+ expression,
+ operator,
+ inProjectList = true
+ )
+ }
+
+ hasAggregateExpressions |= resolvedElementContext.hasAggregateExpressions
+ hasLateralColumnAlias |= resolvedElementContext.hasLateralColumnAlias
+
+ Seq(resolvedElement.asInstanceOf[NamedExpression])
}
+
ResolvedProjectList(
expressions = resolvedProjectList,
- hasAggregateExpressions = projectListResolutionContext.hasAggregateExpressionsInASubtree,
- hasAttributes = projectListResolutionContext.hasAttributeInASubtree,
- hasLateralColumnAlias = projectListResolutionContext.hasLateralColumnAlias
+ hasAggregateExpressions = hasAggregateExpressions,
+ hasLateralColumnAlias = hasLateralColumnAlias
)
}
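The control flow of `resolveProjectList` above is two-phased: expand stars against the current scope first, then resolve each element while OR-ing its per-element flags into list-level flags. A simplified, self-contained sketch of that shape; all types and names below are stand-ins, not Spark classes.

// Simplified sketch of the resolveProjectList control flow: expand "*" entries first,
// then resolve each element and OR its per-element flags into list-level flags.
object ProjectListSketch {
  sealed trait Item
  case object Star extends Item
  final case class Column(name: String, isAggregate: Boolean = false) extends Item

  final case class Resolved(columns: Seq[String], hasAggregates: Boolean)

  def resolveProjectList(items: Seq[Item], scopeOutput: Seq[String]): Resolved = {
    // Phase 1: star expansion against the current scope's output.
    val expanded: Seq[Column] = items.flatMap {
      case Star           => scopeOutput.map(Column(_))
      case column: Column => Seq(column)
    }
    // Phase 2: per-element resolution, accumulating list-level flags.
    var hasAggregates = false
    val resolved = expanded.map { column =>
      hasAggregates |= column.isAggregate
      column.name
    }
    Resolved(resolved, hasAggregates)
  }

  def main(args: Array[String]): Unit = {
    val result = resolveProjectList(
      Seq(Star, Column("count(col1)", isAggregate = true)),
      Seq("col1", "col2")
    )
    println(result) // Resolved(List(col1, col2, count(col1)), true)
  }
}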
/**
- * Resolves [[Expression]] by resolving its children and applying generic type coercion
- * transformations to the resulting expression. This resolution method is used for nodes that
- * require type coercion on top of [[resolveExpressionGenerically]].
+ * Resolve aggregate expressions in [[Aggregate]] operator.
+ *
+ * The [[Aggregate]] list can contain different unresolved expressions before the resolution,
+ * which will be resolved using generic [[resolve]]. However, [[UnresolvedStar]] is a special
+ * case, because it is expanded into a sequence of [[NamedExpression]]s. Because of that, this
+ * method returns a sequence and doesn't conform to the generic [[resolve]] interface - it's called
+ * directly from the [[AggregateResolver]] during [[Aggregate]] resolution.
+ *
+ * Besides resolution, we do the following:
+ * - If there is an [[UnresolvedStar]] in the list, we set `hasStar` to true in order to throw
+ * if there are any ordinals in grouping expressions.
+ * Example of an invalid query:
+ * {{{ SELECT * FROM VALUES(1) GROUP BY 1; }}}
+ *
+ * - If there is an expression which has an aggregate function in its subtree, we add it to the
+ * `expressionsWithAggregateFunctions` list in order to throw if there is any ordinal in
+ * grouping expressions which references that aggregate expression.
+ * Example of an invalid query:
+ * {{{ SELECT count(col1) FROM VALUES(1) GROUP BY 1; }}}
+ *
+ * - If the resolved expression is an [[Alias]], add it to
+ * `scopes.current.topAggregateExpressionsByAliasName` so it can be used for grouping
+ * expressions resolution, if needed.
+ * Example of a query with an [[Alias]]:
+ * 1. Implicit alias:
+ * {{{ SELECT col1 + col2 FROM VALUES(1, 2) GROUP BY `(col1 + col2)`; }}}
+ * 2. Explicit alias:
+ * {{{ SELECT 1 AS column GROUP BY column; }}}
+ *
+ * While resolving the list, we have to keep track of all the expressions that don't have
+ * [[AggregateExpression]]s in their subtrees (`expressionsWithoutAggregates`) and whether any of
+ * the aggregate expressions (those not in `expressionsWithoutAggregates`) have attributes in
+ * their subtrees outside of [[AggregateExpression]]s (`hasAttributeOutsideOfAggregateExpressions`).
+ * This is used when resolving `GROUP BY ALL` in the [[AggregateResolver.resolveGroupByAll]].
+ *
+ * @return A [[ResolvedAggregateExpressions]] encapsulating: the list of resolved expressions,
+ * the list of expressions that don't have an [[AggregateExpression]] in their subtrees,
+ * whether any resolved expression has attributes in its subtree that are not under an
+ * [[AggregateExpression]], whether any expression is a star (`*`), and the list of indices
+ * of expressions that have aggregate functions in their subtrees.
*/
- def resolveExpressionGenericallyWithTypeCoercion(expression: Expression): Expression = {
- val expressionWithResolvedChildren = withResolvedChildren(expression, resolve)
- typeCoercionResolver.resolve(expressionWithResolvedChildren)
+ def resolveAggregateExpressions(
+ sourceUnresolvedAggregateExpressions: Seq[NamedExpression],
+ unresolvedAggregate: Aggregate): ResolvedAggregateExpressions = {
+ val unresolvedAggregateExpressions = tryDrainLazySequences(sourceUnresolvedAggregateExpressions)
+
+ val expressionsWithoutAggregates = new mutable.ArrayBuffer[NamedExpression]
+ val expressionIndexesWithAggregateFunctions = new HashSet[Int]
+ var hasAttributeOutsideOfAggregateExpressions = false
+ var hasStar = false
+
+ val unresolvedAggregateExpressionsWithStarsExpanded = unresolvedAggregateExpressions.flatMap {
+ case unresolvedStar: UnresolvedStar =>
+ hasStar = true
+ resolveStar(unresolvedStar)
+ case other => Seq(other)
+ }
+
+ val resolvedAggregateExpressions =
+ unresolvedAggregateExpressionsWithStarsExpanded.zipWithIndex.flatMap {
+ case (expression, index) =>
+ val (resolvedElement, resolvedElementContext) = resolveExpressionTreeInOperatorImpl(
+ expression,
+ unresolvedAggregate,
+ inProjectList = true
+ )
+
+ resolvedElement match {
+ case alias: Alias =>
+ scopes.current.addTopAggregateExpression(alias)
+ case other =>
+ }
+
+ if (resolvedElementContext.hasAggregateExpressions) {
+ expressionIndexesWithAggregateFunctions.add(index)
+ hasAttributeOutsideOfAggregateExpressions |=
+ resolvedElementContext.hasAttributeOutsideOfAggregateExpressions
+ } else {
+ expressionsWithoutAggregates += resolvedElement.asInstanceOf[NamedExpression]
+ }
+
+ Seq(resolvedElement.asInstanceOf[NamedExpression])
+ }
+
+ val isLcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)
+ if (isLcaEnabled && scopes.current.lcaRegistry.getAliasDependencyLevels().size() > 1) {
+ throw new ExplicitlyUnsupportedResolverFeature("LateralColumnAlias in aggregate expressions")
+ }
+
+ ResolvedAggregateExpressions(
+ expressions = resolvedAggregateExpressions,
+ resolvedExpressionsWithoutAggregates = expressionsWithoutAggregates.toSeq,
+ hasAttributeOutsideOfAggregateExpressions = hasAttributeOutsideOfAggregateExpressions,
+ hasStar = hasStar,
+ expressionIndexesWithAggregateFunctions = expressionIndexesWithAggregateFunctions
+ )
+ }
+
+ /**
+ * Resolve grouping expressions in [[Aggregate]] operator.
+ *
+ * This is done for every expression using `resolveExpressionTreeInOperatorImpl`. For cases where
+ * grouping is done based on aliases, the resolution works as follows:
+ * - If the expression can be resolved using the child's output (`scopes.current.output`),
+ * resolve it that way.
+ * Example:
+ * {{{ SELECT col1 FROM VALUES(1) GROUP BY `col1`; }}}
+ *
+ * - If not, try to resolve it as a top level [[Alias]] (which was populated during the
+ * resolution of the aggregate expressions).
+ * Example:
+ * 1. Group by implicit alias
+ * {{{ SELECT concat_ws(' ', 'a', 'b') GROUP BY `concat_ws( , a, b)`; }}}
+ * 2. Group by explicit alias
+ * {{{ SELECT col1 AS column_1 FROM VALUES(1) GROUP BY column_1; }}}
+ */
+ def resolveGroupingExpressions(
+ sourceUnresolvedGroupingExpressions: Seq[Expression],
+ unresolvedAggregate: Aggregate): Seq[Expression] = {
+ val unresolvedGroupingExpressions = tryDrainLazySequences(sourceUnresolvedGroupingExpressions)
+
+ unresolvedGroupingExpressions.map { expression =>
+ val (resolvedExpression, _) = resolveExpressionTreeInOperatorImpl(
+ expression,
+ unresolvedAggregate,
+ resolvingGroupingExpressions = true
+ )
+
+ resolvedExpression
+ }
+ }
+
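The alias fallback described above can be pictured as a two-step lookup: try the child's output first, then the aliases collected from the aggregate list. A minimal sketch under that assumption; `resolveGroupingName` and the two maps are hypothetical stand-ins for the [[NameScope]] machinery.

// Hedged sketch of the grouping-expression lookup order: child output first,
// then top-level aggregate-list aliases (as with GROUP BY column_1 in the doc above).
object GroupingByAliasSketch {
  def resolveGroupingName(
      name: String,
      childOutput: Map[String, String],         // column name -> resolved reference
      topAggregateAliases: Map[String, String]  // alias name  -> aliased expression
  ): String = {
    childOutput
      .get(name)
      .orElse(topAggregateAliases.get(name))
      .getOrElse(throw new NoSuchElementException(s"Cannot resolve grouping name `$name`"))
  }

  def main(args: Array[String]): Unit = {
    val childOutput = Map("col1" -> "col1#1")
    val aliases = Map("column_1" -> "col1#1 AS column_1")
    println(resolveGroupingName("col1", childOutput, aliases))     // resolved from child output
    println(resolveGroupingName("column_1", childOutput, aliases)) // resolved from aggregate alias
  }
}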
+ /**
+ * Validate whether `expression` is under a supported operator. In case it's not, add `expression`
+ * to the [[ExpressionTreeTraversal.invalidExpressionsInTheContextOfOperator]] list to throw an
+ * error later, when [[getLastInvalidExpressionsInTheContextOfOperator]] is called by the
+ * [[Resolver]].
+ */
+ def validateExpressionUnderSupportedOperator(expression: Expression): Unit = {
+ if (UnsupportedExpressionInOperatorValidation.isExpressionInUnsupportedOperator(
+ expression,
+ traversals.current.parentOperator
+ )) {
+ traversals.current.invalidExpressionsInTheContextOfOperator.add(expression)
+ }
}
private def resolveExpressionTreeInOperatorImpl(
unresolvedExpression: Expression,
- parentOperator: LogicalPlan): (Expression, ExpressionResolutionContext) = {
- this.parentOperators.push(parentOperator)
- expressionResolutionContextStack.push(new ExpressionResolutionContext)
- try {
- val resolvedExpression = resolve(unresolvedExpression)
- (resolvedExpression, expressionResolutionContextStack.peek())
- } finally {
- expressionResolutionContextStack.pop()
- this.parentOperators.pop()
+ parentOperator: LogicalPlan,
+ inProjectList: Boolean = false,
+ resolvingGroupingExpressions: Boolean = false
+ ): (Expression, ExpressionResolutionContext) = {
+ traversals.withNewTraversal(parentOperator) {
+ expressionResolutionContextStack.push(
+ new ExpressionResolutionContext(
+ isRoot = true,
+ isTopOfProjectList = inProjectList,
+ resolvingGroupingExpressions = resolvingGroupingExpressions
+ )
+ )
+
+ try {
+ val resolvedExpression = resolve(unresolvedExpression)
+
+ lastReferencedAttributes = Some(traversals.current.referencedAttributes)
+ lastInvalidExpressionsInTheContextOfOperator =
+ Some(traversals.current.invalidExpressionsInTheContextOfOperator.asScala.toSeq)
+ lastExtractedAggregateExpressionAliases =
+ Some(traversals.current.extractedAggregateExpressionAliases.asScala.toSeq)
+
+ (resolvedExpression, expressionResolutionContextStack.peek())
+ } finally {
+ expressionResolutionContextStack.pop()
+ }
}
}
@@ -365,6 +573,8 @@ class ExpressionResolver(
throw new ExplicitlyUnsupportedResolverFeature("Star outside of Project list")
case attributeReference: AttributeReference =>
handleResolvedAttributeReference(attributeReference)
+ case outerReference: OuterReference =>
+ handleResolvedOuterReference(outerReference)
case _: UnresolvedNamedLambdaVariable =>
throw new ExplicitlyUnsupportedResolverFeature("Lambda variables")
case _ =>
@@ -373,6 +583,32 @@ class ExpressionResolver(
}
}
+ /**
+ * [[UnresolvedStar]] resolution relies on the [[NameScope]]'s ability to get the attributes by a
+ * multipart name ([[UnresolvedStar]]'s `target` field):
+ *
+ * - Star target is defined:
+ *
+ * {{{
+ * SELECT t.* FROM VALUES (1) AS t;
+ * ->
+ * Project [col1#19]
+ * }}}
+ *
+ *
+ * - Star target is not defined:
+ *
+ * {{{
+ * SELECT * FROM (SELECT 1 as col1), (SELECT 2 as col2);
+ * ->
+ * Project [col1#19, col2#20]
+ * }}}
+ */
+ private def resolveStar(unresolvedStar: UnresolvedStar): Seq[NamedExpression] =
+ withPosition(unresolvedStar) {
+ scopes.current.expandStar(unresolvedStar)
+ }
+
/**
* [[UnresolvedAttribute]] resolution relies on [[NameScope]] to lookup the attribute by its
* multipart name. The resolution can result in three different outcomes which are handled in the
@@ -391,94 +627,170 @@ class ExpressionResolver(
* attribute that is a lateral column alias reference. In that case we mark the referenced
* attribute as referenced and tag the LCA attribute for further [[Alias]] resolution.
*
+ * In case the attribute is resolved as a literal function (e.g. the result is [[CurrentDate]]),
+ * perform additional resolution on it.
+ *
* If the attribute is at the top of the project list (which is indicated by
- * [[isTopOfProjectList]]), we preserve the [[Alias]] or remove it otherwise.
+ * [[ExpressionResolutionContext.isTopOfProjectList]]), we preserve the [[Alias]] or remove it
+ * otherwise.
*
- * Finally, we remap the expression ID of a top [[NamedExpression]]. It's not necessary to remap
- * the expression ID of lower expressions, because they already got the appropriate ID from the
+ * Finally, we remap the expression ID of a top [[Alias]]. It's not necessary to remap the
+ * expression ID of lower expressions, because they already got the appropriate ID from the
* current scope output in [[resolveMultipartName]].
*/
private def resolveAttribute(unresolvedAttribute: UnresolvedAttribute): Expression =
withPosition(unresolvedAttribute) {
- expressionResolutionContextStack.peek().hasAttributeInASubtree = true
+ val isLcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)
+ val expressionResolutionContext = expressionResolutionContextStack.peek()
- val nameTarget: NameTarget =
- scopes.top.resolveMultipartName(unresolvedAttribute.nameParts, isLcaEnabled)
+ val nameTarget: NameTarget = scopes.resolveMultipartName(
+ multipartName = unresolvedAttribute.nameParts,
+ canLaterallyReferenceColumn = canLaterallyReferenceColumn(isLcaEnabled),
+ canReferenceAggregateExpressionAliases = (
+ expressionResolutionContextStack
+ .peek()
+ .resolvingGroupingExpressions && conf.groupByAliases
+ ),
+ canResolveNameByHiddenOutput = canResolveNameByHiddenOutput
+ )
val candidate = nameTarget.pickCandidate(unresolvedAttribute)
+ expressionResolutionContext.hasAttributeOutsideOfAggregateExpressions = true
+ if (nameTarget.isOuterReference) {
+ expressionResolutionContext.hasOuterReferences = true
+ } else {
+ expressionResolutionContext.hasLocalReferences = true
+ }
+
if (isLcaEnabled) {
nameTarget.lateralAttributeReference match {
case Some(lateralAttributeReference) =>
- scopes.top.lcaRegistry
+ scopes.current.lcaRegistry
.markAttributeLaterallyReferenced(lateralAttributeReference)
candidate.setTagValue(ExpressionResolver.SINGLE_PASS_IS_LCA, ())
- expressionResolutionContextStack.peek().hasLateralColumnAlias = true
+ expressionResolutionContext.hasLateralColumnAlias = true
case None =>
}
}
+ tryAddReferencedAttribute(candidate)
+
+ val candidateOrLiteralFunction = candidate match {
+ case currentDate: CurrentDate =>
+ timezoneAwareExpressionResolver.resolve(currentDate)
+ case other => other
+ }
+
val properlyAliasedExpressionTree =
- if (isTopOfProjectList && nameTarget.aliasName.isDefined) {
- Alias(candidate, nameTarget.aliasName.get)()
+ if (expressionResolutionContext.isTopOfProjectList && nameTarget.aliasName.isDefined) {
+ Alias(candidateOrLiteralFunction, nameTarget.aliasName.get)()
} else {
- candidate
+ candidateOrLiteralFunction
}
properlyAliasedExpressionTree match {
- case namedExpression: NamedExpression =>
- expressionIdAssigner.mapExpression(namedExpression)
+ case alias: Alias =>
+ expressionIdAssigner.mapExpression(alias)
case _ =>
properlyAliasedExpressionTree
}
}
+ private def canResolveNameByHiddenOutput = traversals.current.parentOperator match {
+ case operator @ (_: Filter | _: Sort) => true
+ case other => false
+ }
+
/**
* [[AttributeReference]] is already resolved if it's passed to us from DataFrame `col(...)`
* function, for example.
+ *
+ * After mapping the [[AttributeReference]] to a correct [[ExprId]], we need to assert that the
+ * attribute exists in current [[NameScope]]. If the attribute in the current scope is nullable,
+ * we need to preserve this nullability in the [[AttributeReference]] as well. This is necessary
+ * because of the following case:
+ *
+ * {{{
+ * val df1 = Seq((1, 1)).toDF("a", "b")
+ * val df2 = Seq((2, 2)).toDF("a", "b")
+ * df1.join(df2, df1("a") === df2("a"), "outer")
+ * .select(coalesce(df1("a"), df1("b")), coalesce(df2("a"), df2("b")))
+ * }}}
+ *
+ * Because of the outer join, after df1.join(df2), the output of the join will have nullable
+ * attributes "a" and "b" that come from df2. When these attributes are referenced in
+ * `select(coalesce(df2("a"), df2("b")))`, they need to inherit the nullability of the join's
+ * output, or retain the nullability of the [[AttributeReference]], if it was true.
+ *
+ * Without this, a nullable column's nullable field can actually be set as non-nullable, which
+ * can cause illegal optimization (e.g., NULL propagation) and wrong answers. See SPARK-13484 and
+ * SPARK-13801 for the concrete queries of this case.
*/
private def handleResolvedAttributeReference(attributeReference: AttributeReference) = {
+ val expressionResolutionContext = expressionResolutionContextStack.peek()
+
+ expressionResolutionContext.hasAttributeOutsideOfAggregateExpressions = true
+
val strippedAttributeReference = tryStripAmbiguousSelfJoinMetadata(attributeReference)
- val resultAttribute = expressionIdAssigner.mapExpression(strippedAttributeReference)
- if (!scopes.top.hasAttributeWithId(resultAttribute.exprId)) {
- throw new ExplicitlyUnsupportedResolverFeature("DataFrame missing attribute propagation")
+ val resultAttribute = if (!inOuterReferenceSubtree) {
+ expressionResolutionContext.hasLocalReferences = true
+
+ expressionIdAssigner.mapExpression(strippedAttributeReference)
+ } else {
+ expressionResolutionContext.hasOuterReferences = true
+
+ expressionIdAssigner.mapOuterReference(strippedAttributeReference)
}
- resultAttribute
+ val existingAttributeWithId = scopes.current.getAttributeById(resultAttribute.exprId)
+ val resultAttributeWithNullability = if (existingAttributeWithId.isEmpty) {
+ resultAttribute
+ } else {
+ val nullability = existingAttributeWithId.get.nullable || resultAttribute.nullable
+ resultAttribute.withNullability(nullability)
+ }
+
+ tryAddReferencedAttribute(resultAttributeWithNullability)
+
+ resultAttributeWithNullability
}
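The nullability handling above amounts to OR-ing the nullability carried by the DataFrame-provided attribute with the nullability the same attribute has in the current scope. A minimal sketch, with a hypothetical `Attr` case class in place of [[AttributeReference]].

// Minimal sketch of the nullability union: an attribute referenced through a DataFrame column
// must be at least as nullable as the matching attribute in the current scope's output.
object NullabilityUnionSketch {
  final case class Attr(exprId: Long, name: String, nullable: Boolean)

  def withScopeNullability(attr: Attr, scopeOutput: Seq[Attr]): Attr =
    scopeOutput.find(_.exprId == attr.exprId) match {
      case Some(inScope) => attr.copy(nullable = inScope.nullable || attr.nullable)
      case None          => attr
    }

  def main(args: Array[String]): Unit = {
    // After an outer join, the scope reports df2's "a" as nullable even though the original
    // DataFrame column was not.
    val fromDataFrame = Attr(exprId = 7L, name = "a", nullable = false)
    val scopeOutput = Seq(Attr(exprId = 7L, name = "a", nullable = true))
    println(withScopeNullability(fromDataFrame, scopeOutput)) // Attr(7,a,true)
  }
}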
/**
- * [[UnresolvedStar]] resolution relies on the [[NameScope]]'s ability to get the attributes by a
- * multipart name ([[UnresolvedStar]]'s `target` field):
- *
- * - Star target is defined:
- *
- * {{{
- * SELECT t.* FROM VALUES (1) AS t;
- * ->
- * Project [col1#19]
- * }}}
- *
- *
- * - Star target is not defined:
- *
- * {{{
- * SELECT * FROM (SELECT 1 as col1), (SELECT 2 as col2);
- * ->
- * Project [col1#19, col2#20]
- * }}}
+ * While handling the resolved [[OuterReference]] we need to set [[inOuterReferenceSubtree]] to
+ * `true` so that [[AttributeReference]] expression IDs are correctly remapped by the
+ * [[ExpressionIdAssigner]] using the outer expression ID mapping.
*/
- def resolveStar(unresolvedStar: UnresolvedStar): Seq[NamedExpression] =
- withPosition(unresolvedStar) {
- scopes.top.expandStar(unresolvedStar)
+ private def handleResolvedOuterReference(outerReference: OuterReference): Expression = {
+ inOuterReferenceSubtree = true
+ try {
+ OuterReference(e = resolve(outerReference.e).asInstanceOf[NamedExpression])
+ } finally {
+ inOuterReferenceSubtree = false
}
+ }
/**
* [[Literal]] resolution doesn't require any specific resolution logic at this point.
*/
private def resolveLiteral(literal: Literal): Expression = literal
+ /**
+ * Resolve [[Predicate]] expression using [[PredicateResolver]]. Subquery expressions are a
+ * special case and require special resolution logic.
+ */
+ private def resolvePredicate(unresolvedPredicate: Predicate): Expression = {
+ unresolvedPredicate match {
+ case unresolvedInSubquery: InSubquery =>
+ subqueryExpressionResolver.resolveInSubquery(unresolvedInSubquery)
+ case unresolvedExists: Exists =>
+ subqueryExpressionResolver.resolveExists(unresolvedExists)
+ case _ =>
+ predicateResolver.resolve(unresolvedPredicate)
+ }
+ }
+
/**
* The [[GetViewColumnByNameAndOrdinal]] is a special internal expression that is placed by the
* [[SessionCatalog]] in the top [[Project]] operator of the freshly reconstructed unresolved
@@ -540,7 +852,7 @@ class ExpressionResolver(
*/
private def resolveGetViewColumnByNameAndOrdinal(
getViewColumnByNameAndOrdinal: GetViewColumnByNameAndOrdinal): Expression = {
- val candidates = scopes.top.findAttributesByName(getViewColumnByNameAndOrdinal.colName)
+ val candidates = scopes.current.findAttributesByName(getViewColumnByNameAndOrdinal.colName)
if (candidates.length != getViewColumnByNameAndOrdinal.expectedNumCandidates) {
throw QueryCompilationErrors.incompatibleViewSchemaChangeError(
getViewColumnByNameAndOrdinal.viewName,
@@ -569,44 +881,44 @@ class ExpressionResolver(
}
/**
- * Resolves [[Expression]] by calling [[timezoneAwareExpressionResolver]] to resolve
- * expression's children and apply timezone if needed. Applies generic type coercion
- * rules to the result.
+ * Collation resolution requires resolving the collation name using [[CollationFactory]].
*/
- private def resolveExpressionGenericallyWithTimezoneWithTypeCoercion(
- timezoneAwareExpression: TimeZoneAwareExpression): Expression = {
- val expressionWithTimezone = timezoneAwareExpressionResolver.resolve(timezoneAwareExpression)
- typeCoercionResolver.resolve(expressionWithTimezone)
+ private def resolveCollation(unresolvedCollation: UnresolvedCollation): Expression = {
+ ResolvedCollation(
+ CollationFactory.resolveFullyQualifiedName(unresolvedCollation.collationName.toArray)
+ )
}
- /**
- * Resolves [[Expression]] only by resolving its children. This resolution method is used for
- * nodes that don't require any special resolution other than resolving its children.
- */
- private def resolveExpressionGenerically(expression: Expression): Expression =
- withResolvedChildren(expression, resolve)
+ private def pushResolutionContext(): Unit = {
+ val parentContext = expressionResolutionContextStack.peek()
+ expressionResolutionContextStack.push(ExpressionResolutionContext.createChild(parentContext))
+ }
private def popResolutionContext(): Unit = {
- val currentExpressionResolutionContext = expressionResolutionContextStack.pop()
- expressionResolutionContextStack.peek().merge(currentExpressionResolutionContext)
+ val childContext = expressionResolutionContextStack.pop()
+ expressionResolutionContextStack.peek().mergeChild(childContext)
}
- private def pushResolutionContext(): Unit = {
- isTopOfProjectList = expressionResolutionContextStack
- .size() == 1 && parentOperators.peek().isInstanceOf[Project]
-
- expressionResolutionContextStack.push(new ExpressionResolutionContext)
+ private def tryAddReferencedAttribute(expression: Expression) = expression match {
+ case attribute: Attribute =>
+ traversals.current.referencedAttributes.put(attribute.exprId, attribute)
+ case extractValue: ExtractValue =>
+ extractValue.foreach {
+ case attribute: Attribute =>
+ traversals.current.referencedAttributes.put(attribute.exprId, attribute)
+ case _ =>
+ }
+ case _ =>
}
- private def tryPopSinglePassSubtreeBoundary(unresolvedExpression: Expression): Boolean = {
- if (unresolvedExpression
- .getTagValue(ExpressionResolver.SINGLE_PASS_SUBTREE_BOUNDARY)
- .isDefined) {
- unresolvedExpression.unsetTagValue(ExpressionResolver.SINGLE_PASS_SUBTREE_BOUNDARY)
- true
- } else {
- false
- }
+ /**
+ * Returns true if LateralColumnAlias resolution is enabled, the current operator is not created
+ * because of a generated column, and the current expression is not a grouping one (grouping
+ * expressions can't reference an LCA).
+ */
+ private def canLaterallyReferenceColumn(isLcaEnabled: Boolean): Boolean = {
+ isLcaEnabled &&
+ !expressionResolutionContextStack.peek().resolvingGroupingExpressions
}
/**
@@ -644,6 +956,67 @@ class ExpressionResolver(
}
}
+ /**
+ * Resolves [[Expression]] only by resolving its children. This resolution method is used for
+ * nodes that don't require any special resolution other than resolving its children.
+ */
+ private def resolveExpressionGenerically(expression: Expression): Expression =
+ withResolvedChildren(expression, resolve _)
+
+ /**
+ * Resolves [[Expression]] by resolving its children and applying generic type coercion
+ * transformations to the resulting expression. This resolution method is used for nodes that
+ * require type coercion on top of [[resolveExpressionGenerically]].
+ */
+ private def resolveExpressionGenericallyWithTypeCoercion(expression: Expression): Expression = {
+ val expressionWithResolvedChildren = withResolvedChildren(expression, resolve _)
+ typeCoercionResolver.resolve(expressionWithResolvedChildren)
+ }
+
+ /**
+ * Resolves [[Expression]] by calling [[timezoneAwareExpressionResolver]] to resolve
+ * expression's children and apply timezone if needed. Applies generic type coercion
+ * rules to the result.
+ */
+ private def resolveExpressionGenericallyWithTimezoneWithTypeCoercion(
+ timezoneAwareExpression: TimeZoneAwareExpression): Expression = {
+ val expressionWithTimezone = timezoneAwareExpressionResolver.resolve(timezoneAwareExpression)
+ typeCoercionResolver.resolve(expressionWithTimezone)
+ }
+
+ private def validateResolvedExpressionGenerically(resolvedExpression: Expression): Unit = {
+ if (resolvedExpression.checkInputDataTypes().isFailure) {
+ TypeCoercionValidation.failOnTypeCheckResult(resolvedExpression)
+ }
+
+ if (!resolvedExpression.resolved) {
+ throwSinglePassFailedToResolveExpression(resolvedExpression)
+ }
+
+ validateExpressionUnderSupportedOperator(resolvedExpression)
+ }
+
+ /**
+ * Transform the list to a non-lazy one. This is needed in order to enforce determinism in the
+ * single-pass resolver. Example:
+ *
+ * val groupByColumns = LazyList(col("key"))
+ * val df = Seq((1, 2)).toDF("key", "value")
+ * df.groupBy(groupByColumns: _*)
+ *
+ * For this case, it is necessary to transform `LazyList` (lazy object) to `List` (non-lazy).
+ * The other object which has to be transformed is `Stream`.
+ */
+ private def tryDrainLazySequences(list: Seq[Expression]): Seq[Expression] = {
+ @nowarn("cat=deprecation")
+ val result = list match {
+ case lazyObject @ (_: LazyList[_] | _: Stream[_]) =>
+ lazyObject.toList
+ case other => other
+ }
+ result
+ }
+
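The determinism issue that `tryDrainLazySequences` guards against is easy to reproduce with plain collections: a `LazyList` evaluates its elements only when traversed, so any side effects (or resolver state reads) happen later than with a strict `List`. A standalone sketch:

// Standalone illustration of why lazy sequences are drained: with a LazyList the element
// expressions are evaluated only when the sequence is finally traversed, so state captured
// "at resolution time" could differ from what a strict List would have seen.
object DrainLazySequencesSketch {
  def drain[A](list: Seq[A]): Seq[A] = list match {
    case lazyList: LazyList[A] => lazyList.toList // force evaluation now
    case other                 => other
  }

  def main(args: Array[String]): Unit = {
    var evaluations = 0
    val lazyColumns = LazyList("key").map { name => evaluations += 1; name }

    println(evaluations)           // 0 -- nothing evaluated yet
    val drained = drain(lazyColumns)
    println(evaluations)           // 1 -- forced eagerly, before any later traversal
    println(drained)               // List(key)
  }
}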
private def throwUnsupportedSinglePassAnalyzerFeature(unresolvedExpression: Expression): Nothing =
throw QueryCompilationErrors.unsupportedSinglePassAnalyzerFeature(
s"${unresolvedExpression.getClass} expression resolution"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionTreeTraversal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionTreeTraversal.scala
new file mode 100644
index 0000000000000..ef3feb11c3998
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionTreeTraversal.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import java.util.{ArrayDeque, ArrayList, HashMap}
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, ExprId}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+/**
+ * Properties of a current expression tree traversal.
+ *
+ * @param parentOperator The parent operator of the current expression tree.
+ * @param invalidExpressionsInTheContextOfOperator The expressions that are invalid in the context
+ * of the current expression tree and its parent operator.
+ * @param referencedAttributes All attributes that are referenced during the resolution of
+ * expression trees.
+ * @param extractedAggregateExpressionAliases The aliases of the [[AggregateExpressions]] that are
+ * extracted during the resolution of expression trees.
+ */
+case class ExpressionTreeTraversal(
+ parentOperator: LogicalPlan,
+ invalidExpressionsInTheContextOfOperator: ArrayList[Expression] = new ArrayList[Expression],
+ referencedAttributes: HashMap[ExprId, Attribute] = new HashMap[ExprId, Attribute],
+ extractedAggregateExpressionAliases: ArrayList[Alias] = new ArrayList[Alias]
+)
+
+/**
+ * The stack of expression tree traversal properties which are accumulated during the resolution
+ * of a certain expression tree. This is filled by the
+ * [[ExpressionResolver.resolveExpressionTreeInOperatorImpl]],
+ * and will usually have size 1. However, in case of subquery expressions we would call
+ * [[ExpressionResolver.resolveExpressionTreeInOperatorImpl]] several times recursively
+ * for each expression tree in the operator tree -> expression tree -> operator tree ->
+ * expression tree -> ... chain. Consider this example:
+ *
+ * {{{
+ * SELECT
+ * col1
+ * FROM
+ * VALUES (1) AS t1
+ * WHERE EXISTS (
+ * SELECT
+ * *
+ * FROM
+ * VALUES (2) AS t2
+ * WHERE
+ * (SELECT col1 FROM VALUES (3) AS t3) == t1.col1
+ * )
+ * }}}
+ *
+ * We would have 3 nested stack entries while resolving the lower scalar subquery (with the `t3`
+ * table).
+ */
+class ExpressionTreeTraversalStack {
+ private val stack = new ArrayDeque[ExpressionTreeTraversal]
+
+ /**
+ * Current expression tree traversal properties. Must exist when resolving an expression tree.
+ */
+ def current: ExpressionTreeTraversal = {
+ if (stack.isEmpty) {
+ throw SparkException.internalError("No current expression tree traversal")
+ }
+ stack.peek()
+ }
+
+ /**
+ * Pushes a new [[ExpressionTreeTraversal]] object, executes the `body` and finally pops the
+ * traversal from the stack.
+ */
+ def withNewTraversal[R](parentOperator: LogicalPlan)(body: => R): R = {
+ stack.push(ExpressionTreeTraversal(parentOperator = parentOperator))
+ try {
+ body
+ } finally {
+ stack.pop()
+ }
+ }
+}
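The stack discipline of [[ExpressionTreeTraversalStack]] (push one context per expression tree, nest further pushes when a subquery triggers another expression-tree resolution, always pop in `finally`) can be mimicked with a tiny standalone class; a `String` stands in for the parent operator here.

import java.util.ArrayDeque

// Hedged sketch of the per-expression-tree traversal stack: each resolution pushes its own
// context, nested subquery resolutions push further contexts, and `finally` guarantees the pop.
object TraversalStackSketch {
  final class TraversalStack[T] {
    private val stack = new ArrayDeque[T]

    def current: T = {
      if (stack.isEmpty) throw new IllegalStateException("No current traversal")
      stack.peek()
    }

    def withNewTraversal[R](context: T)(body: => R): R = {
      stack.push(context)
      try body
      finally stack.pop()
    }
  }

  def main(args: Array[String]): Unit = {
    val traversals = new TraversalStack[String]
    traversals.withNewTraversal("Filter (outer query)") {
      traversals.withNewTraversal("Filter (EXISTS subquery)") {
        println(traversals.current) // Filter (EXISTS subquery)
      }
      println(traversals.current)   // Filter (outer query)
    }
  }
}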
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FilterResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FilterResolver.scala
new file mode 100644
index 0000000000000..d6192ee3d273d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FilterResolver.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.withPosition
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
+import org.apache.spark.sql.types.BooleanType
+
+/**
+ * Resolves [[Filter]] node and its condition.
+ */
+class FilterResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
+ extends TreeNodeResolver[Filter, LogicalPlan]
+ with ResolvesNameByHiddenOutput {
+ override protected val scopes: NameScopeStack = resolver.getNameScopes
+
+ /**
+ * Resolve [[Filter]] by resolving its child and its condition. If an attribute that is used in
+ * the condition is not present in the child's output, we need to try and find the attribute in
+ * the hidden output. If found, update the child's output and place a [[Project]] node on top of
+ * the original [[Filter]] with the original output of the [[Filter]]'s child.
+ *
+ * See [[ResolvesNameByHiddenOutput]] doc for more context.
+ */
+ override def resolve(unresolvedFilter: Filter): LogicalPlan = {
+ val resolvedChild = resolver.resolve(unresolvedFilter.child)
+
+ val partiallyResolvedFilter = unresolvedFilter.copy(child = resolvedChild)
+ val resolvedCondition = expressionResolver.resolveExpressionTreeInOperator(
+ partiallyResolvedFilter.condition,
+ partiallyResolvedFilter
+ )
+
+ val referencedAttributes = expressionResolver.getLastReferencedAttributes
+
+ val resolvedFilter = Filter(resolvedCondition, resolvedChild)
+
+ checkValidFilter(unresolvedFilter, resolvedFilter)
+
+ val missingAttributes: Seq[Attribute] =
+ scopes.current.resolveMissingAttributesByHiddenOutput(referencedAttributes)
+ val resolvedChildWithMissingAttributes =
+ insertMissingExpressions(resolvedChild, missingAttributes)
+ val finalFilter = resolvedFilter.copy(child = resolvedChildWithMissingAttributes)
+
+ retainOriginalOutput(finalFilter, missingAttributes)
+ }
+
+ private def checkValidFilter(unresolvedFilter: Filter, resolvedFilter: Filter): Unit = {
+ withPosition(unresolvedFilter) {
+ val invalidExpressions = expressionResolver.getLastInvalidExpressionsInTheContextOfOperator
+ if (invalidExpressions.nonEmpty) {
+ throwInvalidWhereCondition(resolvedFilter, invalidExpressions)
+ }
+
+ if (resolvedFilter.condition.dataType != BooleanType) {
+ throwDataTypeMismatchFilterNotBoolean(resolvedFilter)
+ }
+ }
+ }
+
+ private def throwInvalidWhereCondition(
+ filter: Filter,
+ invalidExpressions: Seq[Expression]): Nothing = {
+ throw new AnalysisException(
+ errorClass = "INVALID_WHERE_CONDITION",
+ messageParameters = Map(
+ "condition" -> toSQLExpr(filter.condition),
+ "expressionList" -> invalidExpressions.map(_.sql).mkString(", ")
+ )
+ )
+ }
+
+ private def throwDataTypeMismatchFilterNotBoolean(filter: Filter): Nothing =
+ throw new AnalysisException(
+ errorClass = "DATATYPE_MISMATCH.FILTER_NOT_BOOLEAN",
+ messageParameters = Map(
+ "sqlExpr" -> makeCommaSeparatedExpressionString(filter.expressions),
+ "filter" -> toSQLExpr(filter.condition),
+ "type" -> toSQLType(filter.condition.dataType)
+ )
+ )
+
+ private def makeCommaSeparatedExpressionString(expressions: Seq[Expression]): String = {
+ expressions.map(toSQLExpr).mkString(", ")
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala
index 04fd03d17a864..e02cd600b8881 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala
@@ -97,7 +97,8 @@ class FunctionResolver(
if (isCountStarExpansionAllowed(unresolvedFunction)) {
normalizeCountExpression(unresolvedFunction)
} else {
- withResolvedChildren(unresolvedFunction, expressionResolver.resolve)
+ withResolvedChildren(unresolvedFunction, expressionResolver.resolve _)
+ .asInstanceOf[UnresolvedFunction]
}
var resolvedFunction = functionResolution.resolveFunction(functionWithResolvedChildren)
@@ -116,7 +117,9 @@ class FunctionResolver(
// Since this [[InheritAnalysisRules]] node is created by
// [[FunctionResolution.resolveFunction]], we need to re-resolve its replacement
// expression.
- expressionResolver.resolveExpressionGenericallyWithTypeCoercion(inheritAnalysisRules)
+ val resolvedInheritAnalysisRules =
+ withResolvedChildren(inheritAnalysisRules, expressionResolver.resolve _)
+ typeCoercionResolver.resolve(resolvedInheritAnalysisRules)
case aggregateExpression: AggregateExpression =>
// In case `functionResolution.resolveFunction` produces a `AggregateExpression` we
// need to apply further resolution which is done in the
@@ -135,7 +138,7 @@ class FunctionResolver(
typeCoercionResolver.resolve(other)
}
- timezoneAwareExpressionResolver.withResolvedTimezoneCopyTags(
+ timezoneAwareExpressionResolver.withResolvedTimezone(
resolvedFunction,
conf.sessionLocalTimeZone
)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala
index 0c1ed75e1e15b..8d55a52bc0d64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis.resolver
+import java.util.Random
+
import scala.util.control.NonFatal
import org.apache.spark.sql.catalyst.{QueryPlanningTracker, SQLConfHelper}
@@ -37,9 +39,10 @@ import org.apache.spark.sql.internal.SQLConf
* - If the "spark.sql.analyzer.singlePassResolver.dualRunEnabled" is "true", the
* [[HybridAnalyzer]] will invoke the legacy analyzer and optionally _also_ the fixed-point
* one depending on the structure of the unresolved plan. This decision is based on which
- * features are supported by the single-pass Analyzer, and the checking is implemented in
- * the [[ResolverGuard]]. After that we validate the results using the following
- * logic:
+ * features are supported by the single-pass Analyzer, and the checking is implemented in the
+ * [[ResolverGuard]]. Whether the query is actually run in dual-run mode is additionally
+ * determined by the [[SQLConf.ANALYZER_DUAL_RUN_SAMPLE_RATE]] flag value. After that we
+ * validate the results using the following logic:
* - If the fixed-point Analyzer fails and the single-pass one succeeds, we throw an
* appropriate exception (please check the
* [[QueryCompilationErrors.fixedPointFailedSinglePassSucceeded]] method)
@@ -55,23 +58,29 @@ class HybridAnalyzer(
resolverGuard: ResolverGuard,
resolver: Resolver,
extendedResolutionChecks: Seq[LogicalPlan => Unit] = Seq.empty,
- checkSupportedSinglePassFeatures: Boolean = true)
+ exposeExplicitlyUnsupportedResolverFeature: Boolean = false)
extends SQLConfHelper {
private var singlePassResolutionDuration: Option[Long] = None
private var fixedPointResolutionDuration: Option[Long] = None
private val resolverRunner: ResolverRunner =
new ResolverRunner(resolver, extendedResolutionChecks)
+ private val sampleRateGenerator = new Random()
def apply(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = {
+ val passedResolvedGuard = resolverGuard.apply(plan)
val dualRun =
conf.getConf(SQLConf.ANALYZER_DUAL_RUN_LEGACY_AND_SINGLE_PASS_RESOLVER) &&
- checkResolverGuard(plan)
+ passedResolvedGuard && checkDualRunSampleRate()
withTrackedAnalyzerBridgeState(dualRun) {
if (dualRun) {
resolveInDualRun(plan, tracker)
} else if (conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED)) {
- resolveInSinglePass(plan)
+ resolveInSinglePass(plan, tracker)
+ } else if (passedResolvedGuard && conf.getConf(
+ SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED_TENTATIVELY
+ )) {
+ resolveInSinglePassTentatively(plan, tracker)
} else {
resolveInFixedPoint(plan, tracker)
}
@@ -136,7 +145,7 @@ class HybridAnalyzer(
var singlePassException: Option[Throwable] = None
val singlePassResult = try {
val (resolutionDuration, result) = recordDuration {
- Some(resolveInSinglePass(plan))
+ Some(resolveInSinglePass(plan, tracker))
}
singlePassResolutionDuration = Some(resolutionDuration)
result
@@ -160,7 +169,7 @@ class HybridAnalyzer(
case None =>
singlePassException match {
case Some(singlePassEx: ExplicitlyUnsupportedResolverFeature)
- if checkSupportedSinglePassFeatures =>
+ if !exposeExplicitlyUnsupportedResolverFeature =>
fixedPointResult.get
case Some(singlePassEx) =>
throw singlePassEx
@@ -175,12 +184,31 @@ class HybridAnalyzer(
}
}
+ /**
+ * Run the single-pass Analyzer, but fall back to the fixed-point if
+ * [[ExplicitlyUnsupportedResolverFeature]] is thrown.
+ */
+ private def resolveInSinglePassTentatively(
+ plan: LogicalPlan,
+ tracker: QueryPlanningTracker): LogicalPlan = {
+ try {
+ resolveInSinglePass(plan, tracker)
+ } catch {
+ case _: ExplicitlyUnsupportedResolverFeature =>
+ resolveInFixedPoint(plan, tracker)
+ }
+ }
+
/**
* This method is used to run the single-pass Analyzer which will return the resolved plan
* or throw an exception if the resolution fails. Both cases are handled in the caller method.
* */
- private def resolveInSinglePass(plan: LogicalPlan): LogicalPlan =
- resolverRunner.resolve(plan, AnalysisContext.get.getSinglePassResolverBridgeState)
+ private def resolveInSinglePass(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan =
+ resolverRunner.resolve(
+ plan = plan,
+ analyzerBridgeState = AnalysisContext.get.getSinglePassResolverBridgeState,
+ tracker = tracker
+ )
/**
* This method is used to run the legacy Analyzer which will return the resolved plan
@@ -194,6 +222,10 @@ class HybridAnalyzer(
resolvedPlan
}
+ private def checkDualRunSampleRate(): Boolean = {
+ sampleRateGenerator.nextDouble() < conf.getConf(SQLConf.ANALYZER_DUAL_RUN_SAMPLE_RATE)
+ }
+
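The sampling gate added here is a plain Bernoulli draw against the configured rate. A self-contained sketch of that decision, with a raw `Double` in place of the SQLConf entry:

import java.util.Random

// Minimal sketch of the dual-run sampling gate: each query draws a uniform number and only
// enters dual-run mode if the draw falls below the configured sample rate.
object DualRunSampleRateSketch {
  private val sampleRateGenerator = new Random()

  def shouldDualRun(guardPassed: Boolean, dualRunEnabled: Boolean, sampleRate: Double): Boolean =
    dualRunEnabled && guardPassed && sampleRateGenerator.nextDouble() < sampleRate

  def main(args: Array[String]): Unit = {
    val sampled = (1 to 10000).count { _ =>
      shouldDualRun(guardPassed = true, dualRunEnabled = true, sampleRate = 0.1)
    }
    println(s"~10% of queries sampled into dual-run: $sampled / 10000")
  }
}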
private def validateLogicalPlans(fixedPointResult: LogicalPlan, singlePassResult: LogicalPlan) = {
if (fixedPointResult.schema != singlePassResult.schema) {
throw QueryCompilationErrors.hybridAnalyzerOutputSchemaComparisonMismatch(
@@ -213,9 +245,6 @@ class HybridAnalyzer(
NormalizePlan(plan)
}
- private def checkResolverGuard(plan: LogicalPlan): Boolean =
- !checkSupportedSinglePassFeatures || resolverGuard.apply(plan)
-
private def recordDuration[T](thunk: => T): (Long, T) = {
val start = System.nanoTime()
val res = thunk
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierAndCteSubstituor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierAndCteSubstituor.scala
new file mode 100644
index 0000000000000..b527162f5ba2e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierAndCteSubstituor.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
+import org.apache.spark.sql.catalyst.plans.logical.{
+ CTERelationDef,
+ LogicalPlan,
+ SubqueryAlias,
+ UnresolvedWith
+}
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION, UNRESOLVED_RELATION}
+
+/**
+ * The [[IdentifierAndCteSubstitutor]] is responsible for substituting the IDENTIFIERs (not yet
+ * implemented) and CTE references in the unresolved logical plan before the actual resolution
+ * starts (specifically before metadata resolution). This is important for SQL features like WITH
+ * (that could confuse [[MetadataResolver]] with extra [[UnresolvedRelation]]s)
+ * or IDENTIFIER (that "hides" the actual [[UnresolvedRelation]]s).
+ *
+ * We only recurse into the plan if [[IdentifierAndCteSubstitutor.NODES_OF_INTEREST]] are present.
+ * This is done so that [[IdentifierAndCteSubstitutor]] is fast and not invasive.
+ */
+class IdentifierAndCteSubstitutor {
+ private var cteRegistry = new CteRegistry
+
+ /**
+ * This is the main entry point to the substitution process. It takes `unresolvedPlan` and
+ * substitutes CTEs and IDENTIFIERs (not yet implemented) that would otherwise confuse the
+ * metadata resolution process.
+ *
+ * [[CteRegistry]] has to be reset for each new invocation, because CTEs in views are analyzed
+ * in isolation.
+ */
+ def substitutePlan(unresolvedPlan: LogicalPlan): LogicalPlan = {
+ cteRegistry = new CteRegistry
+
+ substitute(unresolvedPlan)
+ }
+
+ private def substitute(unresolvedPlan: LogicalPlan): LogicalPlan = {
+ unresolvedPlan match {
+ case unresolvedWith: UnresolvedWith =>
+ handleWith(unresolvedWith)
+ case unresolvedRelation: UnresolvedRelation =>
+ handleUnresolvedRelation(unresolvedRelation)
+ case _ =>
+ handleOperator(unresolvedPlan)
+ }
+ }
+
+ /**
+ * Handle [[UnresolvedWith]] operator. WITH clause produces unresolved relations in the plan,
+ * that are not actual tables or views, but are just potential [[CTERelationRef]]s. To correctly
+ * detect those CTE references we use the [[CteRegistry]] framework. The actual CTE resolution
+ * is left to the main algebraic pass in the [[Resolver]] - here we replace the
+ * [[UnresolvedRelation]]s with [[UnresolvedCteRelationRef]] to avoid issuing useless catalog
+ * RPCs later on in the [[MetadataResolver]].
+ *
+ * We need to use the [[CteRegistry]] framework, because whether the relation is a CTE reference
+ * or not depends on its position in the logical plan tree:
+ * {{{
+ * CREATE TABLE rel2 (col1 INT);
+ *
+ * WITH rel1 AS (
+ * WITH rel2 AS (
+ * SELECT 1
+ * )
+ * SELECT * FROM rel2 -- `rel2` is a CTE reference
+ * )
+ * SELECT * FROM rel2 -- `rel2` is a table reference
+ * }}}
+ */
+ private def handleWith(unresolvedWith: UnresolvedWith): LogicalPlan = {
+ val cteRelationsAfterSubstitution = unresolvedWith.cteRelations.map { cteRelation =>
+ val (cteName, ctePlan) = cteRelation
+
+ val ctePlanAfter = cteRegistry.withNewScope() {
+ substitute(ctePlan).asInstanceOf[SubqueryAlias]
+ }
+
+ cteRegistry.currentScope.registerCte(cteName, CTERelationDef(ctePlanAfter))
+
+ (cteName, ctePlanAfter)
+ }
+
+ val childAfterSubstitution = cteRegistry.withNewScope() {
+ substitute(unresolvedWith.child)
+ }
+
+ val result = withOrigin(unresolvedWith.origin) {
+ unresolvedWith.copy(
+ child = childAfterSubstitution,
+ cteRelations = cteRelationsAfterSubstitution
+ )
+ }
+ result.copyTagsFrom(unresolvedWith)
+ result
+ }
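The position-dependent CTE lookup from the `rel1`/`rel2` example above is essentially a stack of name scopes: names registered by inner WITH clauses shadow outer ones and go out of scope once the clause is resolved. A standalone sketch of that idea, not the actual [[CteRegistry]] API:

import scala.collection.mutable

// Hedged sketch of scope-dependent CTE name resolution: each WITH clause pushes a scope,
// lookups walk from the innermost scope outwards, and leaving the clause pops the scope.
object CteScopeSketch {
  final class CteNameRegistry {
    // Innermost scope first; each WITH clause pushes a fresh scope while it is being resolved.
    private var scopes: List[mutable.Map[String, String]] = Nil

    def withNewScope[R](body: => R): R = {
      scopes = mutable.Map.empty[String, String] :: scopes
      try body
      finally scopes = scopes.tail
    }

    def register(name: String, definition: String): Unit = scopes.head.put(name, definition)

    def resolve(name: String): Option[String] =
      scopes.collectFirst { case scope if scope.contains(name) => scope(name) }
  }

  def main(args: Array[String]): Unit = {
    val registry = new CteNameRegistry
    registry.withNewScope {
      registry.register("rel1", "outer WITH clause")
      registry.withNewScope {
        registry.register("rel2", "inner WITH clause")
        println(registry.resolve("rel2")) // Some(inner WITH clause): resolved as a CTE reference
      }
      println(registry.resolve("rel2"))   // None: would be treated as a real table reference
    }
  }
}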
+
+ /**
+ * Handle [[UnresolvedRelation]] operator, which could be a CTE reference. If that's the case, we
+ * replace it with [[UnresolvedCteRelationRef]].
+ */
+ private def handleUnresolvedRelation(unresolvedRelation: UnresolvedRelation): LogicalPlan = {
+ if (unresolvedRelation.multipartIdentifier.size == 1) {
+ cteRegistry.resolveCteName(unresolvedRelation.multipartIdentifier.head) match {
+ case Some(_) =>
+ val result = withOrigin(unresolvedRelation.origin) {
+ UnresolvedCteRelationRef(unresolvedRelation.multipartIdentifier.head)
+ }
+ result.copyTagsFrom(unresolvedRelation)
+ result
+ case None =>
+ unresolvedRelation
+ }
+ } else {
+ unresolvedRelation
+ }
+ }
+
+ /**
+ * Handle `unresolvedOperator` generically. We use the [[CteRegistry]] framework to recurse into
+ * its children and subquery expressions.
+ */
+ private def handleOperator(unresolvedOperator: LogicalPlan): LogicalPlan = {
+ val operatorAfterSubstitution = unresolvedOperator match {
+ case operator
+ if !operator.containsAnyPattern(IdentifierAndCteSubstitutor.NODES_OF_INTEREST: _*) =>
+ operator
+
+ case operator if operator.children.size > 1 =>
+ val newChildren = operator.children.map { child =>
+ cteRegistry.withNewScopeUnderMultiChildOperator(operator, child) {
+ substitute(child)
+ }
+ }
+ operator.withNewChildren(newChildren)
+
+ case operator if operator.children.size == 1 =>
+ val newChildren = Seq(substitute(operator.children.head))
+ operator.withNewChildren(newChildren)
+
+ case operator =>
+ operator
+ }
+
+ operatorAfterSubstitution.transformExpressionsWithPruning(
+ _.containsPattern(PLAN_EXPRESSION)
+ ) {
+ case subqueryExpression: SubqueryExpression =>
+ val newPlan = cteRegistry.withNewScope(isRoot = true) {
+ substitute(subqueryExpression.plan)
+ }
+
+ val result = withOrigin(subqueryExpression.origin) {
+ subqueryExpression.withNewPlan(newPlan)
+ }
+ result.copyTagsFrom(subqueryExpression)
+ result
+ }
+ }
+}
+
+object IdentifierAndCteSubstitutor {
+ val NODES_OF_INTEREST = Seq(CTE, UNRESOLVED_RELATION)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/JoinResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/JoinResolver.scala
new file mode 100644
index 0000000000000..b7a5a733159fe
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/JoinResolver.scala
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.analysis.NaturalAndUsingJoinResolution
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.{JoinType, NaturalJoin, UsingJoin}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.util._
+
+/**
+ * Resolves [[Join]] operator by resolving its left and right children and its join condition. If
+ * the unresolved join is [[NaturalJoin]] or [[UsingJoin]], the resulting operator will be
+ * [[Project]], otherwise it will be [[Join]].
+ */
+class JoinResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
+ extends TreeNodeResolver[Join, LogicalPlan] {
+ private val scopes = resolver.getNameScopes
+ private val expressionIdAssigner = expressionResolver.getExpressionIdAssigner
+ private val cteRegistry = resolver.getCteRegistry
+
+ /**
+ * Resolves [[Join]] operator:
+ * - Retrieve old output and child outputs if the operator is already resolved. This is relevant
+ * for partially resolved subtrees from DataFrame programs. Do not regenerate ExprIds if there
+ * are no conflicting ids.
+ * - Resolve each child in the context of a) New [[NameScope]] b) New [[ExpressionIdAssigner]]
+ * mapping. Collect children name scopes to use in [[Join]] output computation.
+ * - Based on the type of [[Join]] (natural, using or other) perform additional transformations
+ * and resolve join condition.
+ * - Return the resulting [[Project]] or [[Join]] with new children optionally wrapped in
+ * [[WithCTE]]. See [[CteScope]] scaladoc for more info.
+ */
+ override def resolve(unresolvedJoin: Join): LogicalPlan = {
+ val (resolvedLeftOperator: LogicalPlan, leftNameScope: NameScope) = resolveJoinChild(
+ unresolvedJoin = unresolvedJoin,
+ child = unresolvedJoin.left
+ )
+
+ val (resolvedRightOperator: LogicalPlan, rightNameScope: NameScope) = resolveJoinChild(
+ unresolvedJoin = unresolvedJoin,
+ child = unresolvedJoin.right
+ )
+
+ ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds(
+ Seq(leftNameScope.output, rightNameScope.output)
+ )
+
+ expressionIdAssigner.createMappingFromChildMappings()
+
+ val partiallyResolvedJoin = unresolvedJoin.copy(
+ left = resolvedLeftOperator,
+ right = resolvedRightOperator
+ )
+
+ handleDifferentTypesOfJoin(
+ unresolvedJoin = unresolvedJoin,
+ partiallyResolvedJoin = partiallyResolvedJoin,
+ leftNameScope = leftNameScope,
+ rightNameScope = rightNameScope
+ )
+ }
+
+ private def resolveJoinChild(
+ unresolvedJoin: Join,
+ child: LogicalPlan): (LogicalPlan, NameScope) = {
+ scopes.withNewScope() {
+ expressionIdAssigner.withNewMapping(collectChildMapping = true) {
+ cteRegistry.withNewScopeUnderMultiChildOperator(
+ unresolvedOperator = unresolvedJoin,
+ unresolvedChild = child
+ ) {
+ val resolvedLeftOperator = resolver.resolve(child)
+ (resolvedLeftOperator, scopes.current)
+ }
+ }
+ }
+ }
+
+ /**
+ * If the type of join is [[NaturalJoin]] or [[UsingJoin]], perform additional transformations in
+ * [[commonNaturalJoinProcessing]]. Otherwise, overwrite current name scope output with the
+ * result of [[Join.computeOutput]].
+ */
+ private def handleDifferentTypesOfJoin(
+ unresolvedJoin: Join,
+ partiallyResolvedJoin: Join,
+ leftNameScope: NameScope,
+ rightNameScope: NameScope): LogicalPlan = partiallyResolvedJoin match {
+ case Join(left, right, UsingJoin(joinType, usingCols), _, hint) =>
+ commonNaturalJoinProcessing(
+ unresolvedJoin = unresolvedJoin,
+ left = left,
+ leftNameScope = leftNameScope,
+ right = right,
+ rightNameScope = rightNameScope,
+ joinType = joinType,
+ joinNames = usingCols,
+ condition = None,
+ hint = hint
+ )
+ case Join(left, right, NaturalJoin(joinType), condition, hint) =>
+ val joinNames = getJoinNamesForNaturalJoin(leftNameScope, rightNameScope)
+ commonNaturalJoinProcessing(
+ unresolvedJoin = unresolvedJoin,
+ left = left,
+ leftNameScope = leftNameScope,
+ right = right,
+ rightNameScope = rightNameScope,
+ joinType = joinType,
+ joinNames = joinNames,
+ condition = condition,
+ hint = hint
+ )
+ case partiallyResolvedJoin: Join =>
+ handleRegularJoin(
+ unresolvedJoin = unresolvedJoin,
+ partiallyResolvedJoin = partiallyResolvedJoin,
+ leftNameScope = leftNameScope,
+ rightNameScope = rightNameScope
+ )
+ }
+
+ /**
+ * This method handles [[NaturalJoin]] and [[UsingJoin]] by computing their correct outputs and
+ * placing a [[Project]] node on top of them.
+ * The order of necessary operations is as follows:
+ * - Compute output list, hidden list and new condition with join pairs, if there are any.
+ * [[NaturalAndUsingJoinResolution.computeJoinOutputsAndNewCondition]] introduces new
+ * aliased expressions (e.g. [[Coalesce]] for keys), so we need to run the output list through
+ * [[ExpressionIdAssigner.mapExpression]].
+ * - Resolve the new condition.
+ * - Compute new `hiddenOutput` by appending elements from computed hidden list that are not
+ * already in current `hiddenOutput`. Hidden list must be qualified access only.
+ * - Overwrite current name scope with output list and newly computed hidden output.
+ * - Finally, put a [[Project]] node on top of the original [[Join]], where:
+ * - The new project list becomes the output list.
+ * - If [[Join]] was not a top-level operator, the current hidden output is appended to the
+ * project list.
+ * - The new hidden output is added as a tag to the [[Project]] node in order to stay
+ * compatible with fixed-point. This should never be used in single-pass, but it can happen
+ * that fixed-point uses the single-pass result, therefore we need to set the tag.
+ */
+ private def commonNaturalJoinProcessing(
+ unresolvedJoin: Join,
+ left: LogicalPlan,
+ leftNameScope: NameScope,
+ right: LogicalPlan,
+ rightNameScope: NameScope,
+ joinType: JoinType,
+ joinNames: Seq[String],
+ condition: Option[Expression],
+ hint: JoinHint): LogicalPlan = {
+ val (outputList, hiddenList, newCondition) =
+ NaturalAndUsingJoinResolution.computeJoinOutputsAndNewCondition(
+ left = left,
+ leftOutput = leftNameScope.output,
+ right = right,
+ rightOutput = rightNameScope.output,
+ joinType = joinType,
+ joinNames = joinNames,
+ condition = condition,
+ resolveName = conf.resolver
+ )
+
+ val newOutputList = outputList.map(expressionIdAssigner.mapExpression)
+
+ val resolvedCondition =
+ resolveJoinCondition(unresolvedJoin, newCondition, leftNameScope, rightNameScope)
+
+ val hiddenListWithQualifiedAccess = hiddenList.map(_.markAsQualifiedAccessOnly())
+
+ val newHiddenOutput = hiddenListWithQualifiedAccess ++ scopes.current.hiddenOutput
+
+ scopes.overwriteCurrent(
+ output = Some(newOutputList.map(_.toAttribute)),
+ hiddenOutput = Some(newHiddenOutput)
+ )
+
+ val newProjectList =
+ if (unresolvedJoin.getTagValue(Resolver.TOP_LEVEL_OPERATOR).isEmpty) {
+ newOutputList ++ scopes.current.hiddenOutput
+ .filter(attribute => attribute.qualifiedAccessOnly)
+ } else {
+ newOutputList
+ }
+
+ val project = Project(newProjectList, Join(left, right, joinType, resolvedCondition, hint))
+
+ project.setTagValue(Project.hiddenOutputTag, newHiddenOutput)
+
+ project
+ }
+
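For USING/NATURAL joins the computed output places one column per join key first (coalesced across both sides in the full outer case; taken from one side for other join types), followed by the remaining columns of each side. A plain-name sketch of that shape, ignoring attribute metadata and the hidden output; the coalescing shown assumes the full outer case.

// Simplified sketch of the USING/NATURAL join output shape: one column per join key,
// followed by the remaining columns of the left side and then of the right side.
// The coalesce form applies to full outer joins; other join types pick one side's key.
object UsingJoinOutputSketch {
  def usingJoinOutput(
      leftCols: Seq[String],
      rightCols: Seq[String],
      usingCols: Seq[String]): Seq[String] = {
    val keyColumns = usingCols.map(key => s"coalesce(left.$key, right.$key) AS $key")
    val leftRemainder = leftCols.filterNot(usingCols.contains)
    val rightRemainder = rightCols.filterNot(usingCols.contains)
    keyColumns ++ leftRemainder ++ rightRemainder
  }

  def main(args: Array[String]): Unit = {
    println(usingJoinOutput(Seq("a", "b"), Seq("a", "c"), Seq("a")))
    // List(coalesce(left.a, right.a) AS a, b, c)
  }
}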
+ /**
+ * Resolve a join that is not [[NaturalJoin]] or [[UsingJoin]]. In order to resolve the join we
+ * do the following:
+ * - Resolve join condition.
+ * - Overwrite [[NameScope.output]] with join's output.
+ * - Wrap the [[Join]] in [[WithCTE]] if necessary.
+ */
+ private def handleRegularJoin(
+ unresolvedJoin: Join,
+ partiallyResolvedJoin: Join,
+ leftNameScope: NameScope,
+ rightNameScope: NameScope) = {
+ val resolvedCondition = resolveJoinCondition(
+ unresolvedJoin = unresolvedJoin,
+ unresolvedCondition = partiallyResolvedJoin.condition,
+ leftNameScope = leftNameScope,
+ rightNameScope = rightNameScope
+ )
+
+ scopes.overwriteCurrent(
+ output = Some(
+ Join.computeOutput(
+ partiallyResolvedJoin.joinType,
+ leftNameScope.output,
+ rightNameScope.output
+ )
+ ),
+ hiddenOutput = Some(leftNameScope.hiddenOutput ++ rightNameScope.hiddenOutput)
+ )
+
+ val resolvedJoin = partiallyResolvedJoin.copy(condition = resolvedCondition)
+
+ cteRegistry.currentScope.tryPutWithCTE(
+ unresolvedOperator = unresolvedJoin,
+ resolvedOperator = resolvedJoin
+ )
+ }
+
+ /**
+ * Computes the intersection of two child name scopes, by name.
+ */
+ private def getJoinNamesForNaturalJoin(
+ leftNameScope: NameScope,
+ rightNameScope: NameScope): Seq[String] = {
+ leftNameScope.output
+ .flatMap(attribute => rightNameScope.findAttributesByName(attribute.name))
+ .map(_.name)
+ }
+
+ /**
+ * Resolves the join condition using __all__ attributes from the child scopes. We need to
+ * overwrite the current scope first to prepare for [[resolveExpressionTreeInOperator]]. The
+ * [[Join]] will actually produce a different output than the one we are setting here, so an
+ * additional overwrite with the correct values will be needed. Two overwrites are necessary
+ * because the condition is resolved against the original children's outputs, whereas the output
+ * of [[Join]] will either not contain all attributes or will have different nullabilities.
+ */
+ private def resolveJoinCondition(
+ unresolvedJoin: Join,
+ unresolvedCondition: Option[Expression],
+ leftNameScope: NameScope,
+ rightNameScope: NameScope) = {
+ scopes.overwriteCurrent(
+ output = Some(leftNameScope.output ++ rightNameScope.output),
+ hiddenOutput = Some(leftNameScope.hiddenOutput ++ rightNameScope.hiddenOutput)
+ )
+
+ unresolvedCondition.map { condition =>
+ expressionResolver.resolveExpressionTreeInOperator(
+ condition,
+ unresolvedJoin
+ )
+ }
+ }
+}
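
As an aside on the natural-join handling above, here is a minimal, framework-free sketch (not the resolver code; `leftOutput`/`rightOutput` are illustrative stand-ins for the child scopes' outputs) of how the join key names are computed as the case-insensitive name intersection of the two child outputs, mirroring `getJoinNamesForNaturalJoin`:

object NaturalJoinNamesSketch {
  // Case-insensitive name intersection, analogous to getJoinNamesForNaturalJoin above.
  def joinNames(leftOutput: Seq[String], rightOutput: Seq[String]): Seq[String] = {
    val rightNames = rightOutput.map(_.toLowerCase).toSet
    leftOutput.filter(name => rightNames.contains(name.toLowerCase))
  }

  def main(args: Array[String]): Unit = {
    // nt1(k, v1) NATURAL JOIN nt2(K, v2) joins on the common name `k`.
    println(joinNames(Seq("k", "v1"), Seq("K", "v2"))) // List(k)
  }
}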
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitLikeExpressionValidator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitLikeExpressionValidator.scala
new file mode 100644
index 0000000000000..b35516fce3142
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitLikeExpressionValidator.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * The [[LimitLikeExpressionValidator]] validates [[LocalLimit]], [[GlobalLimit]], [[Offset]] or
+ * [[Tail]] integer expressions.
+ */
+class LimitLikeExpressionValidator extends QueryErrorsBase {
+ def validateLimitLikeExpr(
+ limitLikeExpression: Expression,
+ partiallyResolvedLimitLike: LogicalPlan
+ ): Expression = {
+ val evaluatedExpression =
+ evaluateLimitLikeExpression(limitLikeExpression, partiallyResolvedLimitLike)
+
+ checkValidLimitWithOffset(evaluatedExpression, partiallyResolvedLimitLike)
+
+ limitLikeExpression
+ }
+
+ private def checkValidLimitWithOffset(evaluatedExpression: Int, plan: LogicalPlan): Unit = {
+ plan match {
+ case LocalLimit(_, Offset(offsetExpr, _)) =>
+ val offset = offsetExpr.eval().asInstanceOf[Int]
+ if (Int.MaxValue - evaluatedExpression < offset) {
+ throwInvalidLimitWithOffsetSumExceedsMaxInt(evaluatedExpression, offset, plan)
+ }
+ case _ =>
+ }
+ }
+
+ /**
+ * Evaluate a resolved limit expression of [[GlobalLimit]], [[LocalLimit]], [[Offset]] or
+ * [[Tail]], while performing required checks:
+ * - The expression has to be foldable
+ * - The result data type has to be [[IntegerType]]
+ * - The evaluated expression has to be non-null
+ * - The evaluated expression has to be positive
+ *
+ * The `foldable` check is implemented in some expressions
+ * as a recursive expression tree traversal.
+ * It is not an ideal approach for the single-pass [[ExpressionResolver]],
+ * but __is__ practical, since:
+ * - We have to call `eval` here anyway, and it's recursive
+ * - In practice `LIMIT`, `OFFSET` and `TAIL` expression trees are very small
+ *
+ * The return type of the evaluation is Int, as we check that the expression has
+ * IntegerType.
+ */
+ private def evaluateLimitLikeExpression(expression: Expression, plan: LogicalPlan): Int = {
+ val operatorName = plan match {
+ case _: Offset => "offset"
+ case _: Tail => "tail"
+ case _: LocalLimit | _: GlobalLimit => "limit"
+ case other =>
+ throw SparkException.internalError(
+ s"Unexpected limit like operator type: ${other.getClass.getName}"
+ )
+ }
+ if (!expression.foldable) {
+ throwInvalidLimitLikeExpressionIsUnfoldable(operatorName, expression)
+ }
+ if (expression.dataType != IntegerType) {
+ throwInvalidLimitLikeExpressionDataType(operatorName, expression)
+ }
+ expression.eval() match {
+ case null =>
+ throwInvalidLimitLikeExpressionIsNull(operatorName, expression)
+ case value: Int if value < 0 =>
+ throwInvalidLimitLikeExpressionIsNegative(operatorName, expression, value)
+ case result =>
+ result.asInstanceOf[Int]
+ }
+ }
+
+ private def throwInvalidLimitWithOffsetSumExceedsMaxInt(
+ limit: Int,
+ offset: Int,
+ plan: LogicalPlan): Nothing =
+ throw new AnalysisException(
+ errorClass = "SUM_OF_LIMIT_AND_OFFSET_EXCEEDS_MAX_INT",
+ messageParameters = Map("limit" -> limit.toString, "offset" -> offset.toString),
+ origin = plan.origin
+ )
+
+ private def throwInvalidLimitLikeExpressionIsUnfoldable(
+ name: String,
+ expression: Expression): Nothing =
+ throw new AnalysisException(
+ errorClass = "INVALID_LIMIT_LIKE_EXPRESSION.IS_UNFOLDABLE",
+ messageParameters = Map("name" -> name, "expr" -> toSQLExpr(expression)),
+ origin = expression.origin
+ )
+
+ private def throwInvalidLimitLikeExpressionDataType(
+ name: String,
+ expression: Expression): Nothing =
+ throw new AnalysisException(
+ errorClass = "INVALID_LIMIT_LIKE_EXPRESSION.DATA_TYPE",
+ messageParameters = Map(
+ "name" -> name,
+ "expr" -> toSQLExpr(expression),
+ "dataType" -> toSQLType(expression.dataType)
+ ),
+ origin = expression.origin
+ )
+
+ private def throwInvalidLimitLikeExpressionIsNull(name: String, expression: Expression): Nothing =
+ throw new AnalysisException(
+ errorClass = "INVALID_LIMIT_LIKE_EXPRESSION.IS_NULL",
+ messageParameters = Map(
+ "name" -> name,
+ "expr" -> toSQLExpr(expression)
+ ),
+ origin = expression.origin
+ )
+
+ private def throwInvalidLimitLikeExpressionIsNegative(
+ name: String,
+ expression: Expression,
+ value: Int): Nothing =
+ throw new AnalysisException(
+ errorClass = "INVALID_LIMIT_LIKE_EXPRESSION.IS_NEGATIVE",
+ messageParameters =
+ Map("name" -> name, "expr" -> toSQLExpr(expression), "v" -> toSQLValue(value, IntegerType)),
+ origin = expression.origin
+ )
+}
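
To make the validation order above concrete, here is a hedged, standalone sketch operating on plain values rather than Catalyst expressions (the messages and names are illustrative, not Spark's error classes):

object LimitLikeValidationSketch {
  def validate(name: String, value: Any): Int = value match {
    case null            => throw new IllegalArgumentException(s"$name must not be null")
    case i: Int if i < 0 => throw new IllegalArgumentException(s"$name must be non-negative, got $i")
    case i: Int          => i
    case other           => throw new IllegalArgumentException(s"$name must be an INT, got $other")
  }

  def checkLimitWithOffset(limit: Int, offset: Int): Unit = {
    // Mirrors the guard above: limit + offset must not overflow Int.MaxValue.
    if (Int.MaxValue - limit < offset) {
      throw new IllegalArgumentException(s"limit ($limit) + offset ($offset) exceeds Int.MaxValue")
    }
  }

  def main(args: Array[String]): Unit = {
    checkLimitWithOffset(validate("limit", 10), validate("offset", 5)) // passes
  }
}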
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala
index 0d57eae0be7a1..b54824c348194 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala
@@ -20,8 +20,13 @@ package org.apache.spark.sql.catalyst.analysis.resolver
import java.util.ArrayDeque
import org.apache.spark.sql.catalyst.analysis.{RelationResolution, UnresolvedRelation}
-import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
-import org.apache.spark.sql.catalyst.plans.logical.{AnalysisHelper, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{
+ AnalysisHelper,
+ LogicalPlan,
+ SubqueryAlias,
+ UnresolvedWith
+}
import org.apache.spark.sql.connector.catalog.CatalogManager
/**
@@ -54,9 +59,8 @@ class MetadataResolver(
* calls for the [[UnresolvedRelation]]s present in that tree. During the `unresolvedPlan`
* traversal we fill [[relationsWithResolvedMetadata]] with resolved metadata by relation id.
* This map will be used to resolve the plan in single-pass by the [[Resolver]] using
- * [[getRelationWithResolvedMetadata]]. If the generic metadata resolution using
- * [[RelationResolution]] wasn't successful, we resort to using [[extensions]].
- * Otherwise, we fail with an exception.
+ * [[getRelationWithResolvedMetadata]]. We always try to complete the default resolution using
+ * extensions.
*/
override def resolve(unresolvedPlan: LogicalPlan): Unit = {
traverseLogicalPlanTree(unresolvedPlan) {
@@ -64,20 +68,26 @@ class MetadataResolver(
val relationId = relationIdFromUnresolvedRelation(unresolvedRelation)
if (!relationsWithResolvedMetadata.containsKey(relationId)) {
- val relationWithResolvedMetadata = resolveRelation(unresolvedRelation).orElse {
- // In case the generic metadata resolution returned `None`, we try to check if any
- // of the [[extensions]] matches this `unresolvedRelation`, and resolve it using
- // that extension.
- tryDelegateResolutionToExtension(unresolvedRelation, prohibitedResolver)
+ val relationAfterDefaultResolution =
+ resolveRelation(unresolvedRelation).getOrElse(unresolvedRelation)
+
+ val relationAfterExtensionResolution = relationAfterDefaultResolution match {
+ case subqueryAlias: SubqueryAlias =>
+ tryDelegateResolutionToExtension(subqueryAlias.child, prohibitedResolver).map {
+ relation =>
+ subqueryAlias.copy(child = relation)
+ }
+ case _ =>
+ tryDelegateResolutionToExtension(relationAfterDefaultResolution, prohibitedResolver)
}
- relationWithResolvedMetadata match {
- case Some(relationWithResolvedMetadata) =>
+ relationAfterExtensionResolution.getOrElse(relationAfterDefaultResolution) match {
+ case _: UnresolvedRelation =>
+ case relationWithResolvedMetadata =>
relationsWithResolvedMetadata.put(
relationId,
relationWithResolvedMetadata
)
- case None =>
}
}
case _ =>
@@ -109,18 +119,21 @@ class MetadataResolver(
case Left(logicalPlan) =>
visitor(logicalPlan)
- for (child <- logicalPlan.children) {
- stack.push(Left(child))
- }
- for (innerChild <- logicalPlan.innerChildren) {
- innerChild match {
- case plan: LogicalPlan =>
- stack.push(Left(plan))
- case _ =>
- }
- }
- for (expression <- logicalPlan.expressions) {
- stack.push(Right(expression))
+ logicalPlan match {
+ case unresolvedWith: UnresolvedWith =>
+ for (cteRelation <- unresolvedWith.cteRelations) {
+ stack.push(Left(cteRelation._2))
+ }
+
+ stack.push(Left(unresolvedWith.child))
+ case _ =>
+ for (child <- logicalPlan.children) {
+ stack.push(Left(child))
+ }
+
+ for (expression <- logicalPlan.expressions) {
+ stack.push(Right(expression))
+ }
}
case Right(expression) =>
for (child <- expression.children) {
@@ -128,12 +141,8 @@ class MetadataResolver(
}
expression match {
- case planExpression: PlanExpression[_] =>
- planExpression.plan match {
- case plan: LogicalPlan =>
- stack.push(Left(plan))
- case _ =>
- }
+ case subqueryExpression: SubqueryExpression =>
+ stack.push(Left(subqueryExpression.plan))
case _ =>
}
}
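
The traversal change above replaces a generic child/expression walk with explicit handling of CTE definitions and subquery plans. A rough, framework-free sketch of the same stack-based pattern (the `Node` hierarchy is illustrative, not Catalyst's):

object TraversalSketch {
  sealed trait Node { def children: Seq[Node] }
  case class Relation(name: String, children: Seq[Node] = Nil) extends Node
  case class Subquery(plan: Node) extends Node { def children: Seq[Node] = Seq(plan) }
  case class Op(children: Seq[Node]) extends Node

  // Iterative depth-first traversal with an explicit stack, collecting relation names.
  def collectRelations(root: Node): Seq[String] = {
    val stack = new java.util.ArrayDeque[Node]
    val found = scala.collection.mutable.ArrayBuffer.empty[String]
    stack.push(root)
    while (!stack.isEmpty) {
      stack.pop() match {
        case Relation(name, children) =>
          found += name
          children.foreach(stack.push)
        case other =>
          other.children.foreach(stack.push)
      }
    }
    found.toSeq
  }

  def main(args: Array[String]): Unit = {
    val plan = Op(Seq(Relation("t1"), Subquery(Relation("t2"))))
    println(collectRelations(plan)) // collects t2 and t1
  }
}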
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala
index 5e6aa65830406..86fc43fd52243 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala
@@ -17,14 +17,30 @@
package org.apache.spark.sql.catalyst.analysis.resolver
-import java.util.{ArrayDeque, HashSet}
+import java.util.{ArrayDeque, HashMap, LinkedHashMap}
import scala.collection.mutable
+import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.analysis.{Resolver => NameComparator, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSeq, ExprId, NamedExpression}
+import org.apache.spark.sql.catalyst.analysis.{
+ LiteralFunctionResolution,
+ Resolver => NameComparator,
+ UnresolvedStar
+}
+import org.apache.spark.sql.catalyst.expressions.{
+ Alias,
+ Attribute,
+ AttributeSeq,
+ Expression,
+ ExprId,
+ ExtractValue,
+ NamedExpression,
+ OuterReference
+}
+import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.Metadata
/**
* The [[NameScope]] is used to control the resolution of names (table, column, alias identifiers).
@@ -32,8 +48,8 @@ import org.apache.spark.sql.internal.SQLConf
* program operators.
*
* The [[NameScope]] output is immutable. If it's necessary to update the output,
- * [[NameScopeStack]] methods are used ([[overwriteTop]] or [[withNewScope]]). The [[NameScope]]
- * is always used through the [[NameScopeStack]].
+ * [[NameScopeStack]] methods are used ([[overwriteCurrent]] or [[withNewScope]]). The
+ * [[NameScope]] is always used through the [[NameScopeStack]].
*
* The resolution of identifiers is case-insensitive.
*
@@ -41,16 +57,35 @@ import org.apache.spark.sql.internal.SQLConf
*
* 1. Resolution of local references:
* - column reference
+ * - parameterless function reference
* - struct field or map key reference
* 2. Resolution of lateral column aliases (if enabled).
+ * 3. In the context of [[Aggregate]]: resolution of names in the grouping expressions list referencing
+ * aliases in aggregate expressions.
*
- * For example, in a query like:
+ * The following examples showcase the priority of name resolution:
*
* {{{ SELECT 1 AS col1, col1 FROM VALUES (2) }}}
*
* Because column resolution has a higher priority than LCA resolution, the result will be [1, 2]
* and not [1, 1].
*
+ * {{{
+ * CREATE TABLE t AS SELECT col1 as current_date FROM VALUES (2);
+ *
+ * SELECT
+ * 1 AS current_timestamp,
+ * current_timestamp,
+ * current_date
+ * FROM
+ * t;
+ * }}}
+ *
+ * The result of the previous SELECT will be: [1, 2025-02-13T07:55:26.206+00:00, 2]. As can be
+ * seen, because of the resolution precedence, current_date is resolved as a table column, but
+ * current_timestamp is resolved as a parameterless function rather than as a lateral column
+ * reference.
+ *
* Approximate tree of [[NameScope]] manipulations is shown in the following example:
*
* {{{
@@ -71,33 +106,50 @@ import org.apache.spark.sql.internal.SQLConf
* unionAttributes = withNewScope {
* lhsOutput = withNewScope {
* expandedStar = withNewScope {
- * scope.overwriteTop(localRelation.output)
+ * scopes.overwriteCurrent(localRelation.output)
* scope.expandStar(star)
* }
- * scope.overwriteTop(expandedStar)
+ * scopes.overwriteCurrent(expandedStar)
* scope.output
* }
* rhsOutput = withNewScope {
* subqueryAttributes = withNewScope {
- * scope.overwriteTop(t1.output)
- * scope.overwriteTop(prependQualifier(scope.output, "t2"))
+ * scopes.overwriteCurrent(t1.output)
+ * scopes.overwriteCurrent(prependQualifier(scope.output, "t2"))
* [scope.matchMultiPartName("t2", "col1"), scope.matchMultiPartName("t2", "col2")]
* }
- * scope.overwriteTop(subqueryAttributes)
+ * scopes.overwriteCurrent(subqueryAttributes)
* scope.output
* }
- * scope.overwriteTop(coerce(lhsOutput, rhsOutput))
+ * scopes.overwriteCurrent(coerce(lhsOutput, rhsOutput))
* [scope.matchMultiPartName("col1"), alias(scope.matchMultiPartName("col2"), "alias1")]
* }
- * scope.overwriteTop(unionAttributes)
+ * scopes.overwriteCurrent(unionAttributes)
* }}}
*
* @param output These are the attributes visible for lookups in the current scope.
* These may be:
* - Transformed outputs of lower scopes (e.g. type-coerced outputs of [[Union]]'s children).
* - Output of a current operator that is being resolved (leaf nodes like [[Relations]]).
+ * @param hiddenOutput Attributes that are not directly visible in the scope, but available for
+ * lookup in case the resolved attribute is not found in `output`.
+ * @param isSubqueryRoot Indicates that the current scope is a root of a subquery. This is used by
+ * [[NameScopeStack.resolveMultipartName]] to detect the nearest outer scope.
*/
-class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper {
+class NameScope(
+ val output: Seq[Attribute] = Seq.empty,
+ val hiddenOutput: Seq[Attribute] = Seq.empty,
+ val isSubqueryRoot: Boolean = false)
+ extends SQLConfHelper {
+
+ /**
+ * This is an internal class used to store a resolved multipart name, with the correct precedence
+ * as specified by the [[NameScope]] class doc.
+ */
+ private case class ResolvedMultipartName(
+ candidates: Seq[Expression],
+ referencedAttribute: Option[Attribute],
+ aliasMetadata: Option[Metadata] = None)
/**
* [[nameComparator]] is a function that is used to compare two identifiers. Its implementation
@@ -107,10 +159,28 @@ class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper {
private val nameComparator: NameComparator = conf.resolver
/**
- * [[attributesForResolution]] is an [[AttributeSeq]] that is used for resolution of
- * multipart attribute names. It's created from the `attributes` when [[NameScope]] is updated.
+ * [[attributesForResolution]] is an [[AttributeSeq]] that is used for resolution of multipart
+ * attribute names, by output. It's created from the `output` when
+ * [[NameScope]] is updated.
+ */
+ private val attributesForResolution: AttributeSeq =
+ AttributeSeq.fromNormalOutput(output)
+
+ /**
+ * [[hiddenAttributesForResolution]] is an [[AttributeSeq]] that is used for resolution of
+ * multipart attribute names, by hidden output. It's created from the `hiddenOutput` when
+ * [[NameScope]] is updated.
+ */
+ private lazy val hiddenAttributesForResolution: AttributeSeq =
+ AttributeSeq.fromNormalOutput(hiddenOutput)
+
+ /**
+ * [[metadataAttributesForResolution]] is an [[AttributeSeq]] that is used for resolution of
+ * multipart attribute names, by qualified access only columns from hidden output. It's created
+ * from the `hiddenOutput` when [[NameScope]] is updated.
*/
- private val attributesForResolution: AttributeSeq = AttributeSeq.fromNormalOutput(output)
+ private lazy val metadataAttributesForResolution: AttributeSeq =
+ AttributeSeq.fromNormalOutput(hiddenOutput.filter(_.qualifiedAccessOnly))
/**
* [[attributesByName]] is used to look up attributes by one-part name from the operator's output.
@@ -120,15 +190,73 @@ class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper {
private lazy val attributesByName = createAttributesByName(output)
/**
- * Expression IDs from `output`. See [[hasAttributeWithId]] for more details.
+ * Returns a map of [[ExprId]] to [[Attribute]] from `output`. See [[getAttributeById]] for
+ * more details.
+ */
+ private lazy val attributesById: HashMap[ExprId, Attribute] = createAttributeIds(output)
+
+ /**
+ * Returns a map of [[ExprId]] to [[Attribute]] from `hiddenOutput`.
*/
- private lazy val attributeIds = createAttributeIds(output)
+ private lazy val hiddenAttributesById: HashMap[ExprId, Attribute] =
+ createAttributeIds(hiddenOutput)
- private val isLcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)
- lazy val lcaRegistry: LateralColumnAliasRegistry = if (isLcaEnabled) {
- new LateralColumnAliasRegistryImpl(output)
- } else {
- new LateralColumnAliasProhibitedRegistry
+ lazy val lcaRegistry: LateralColumnAliasRegistry =
+ if (conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
+ new LateralColumnAliasRegistryImpl(output)
+ } else {
+ new LateralColumnAliasProhibitedRegistry
+ }
+
+ /**
+ * All aliased aggregate expressions from an [[Aggregate]] that is currently being resolved.
+ * Used in [[resolveMultipartName]] to resolve names in the grouping expressions list referencing
+ * aggregate expressions.
+ */
+ private lazy val topAggregateExpressionsByAliasName: IdentifierMap[Alias] =
+ new IdentifierMap[Alias]
+
+ /**
+ * Returns new [[NameScope]] which preserves all the immutable [[NameScope]] properties but
+ * overwrites `output` and `hiddenOutput` if provided. Mutable state like `lcaRegistry` is not
+ * preserved.
+ */
+ def overwriteOutput(
+ output: Option[Seq[Attribute]] = None,
+ hiddenOutput: Option[Seq[Attribute]] = None): NameScope = {
+ new NameScope(
+ output = output.getOrElse(this.output),
+ hiddenOutput = hiddenOutput.getOrElse(this.hiddenOutput),
+ isSubqueryRoot = isSubqueryRoot
+ )
+ }
+
+ /**
+ * Given the referenced attributes, returns all attributes that are referenced and missing from
+ * the current output, but can be found in the hidden output.
+ */
+ def resolveMissingAttributesByHiddenOutput(
+ referencedAttributes: HashMap[ExprId, Attribute]): Seq[Attribute] = {
+ val distinctMissingAttributes = new LinkedHashMap[ExprId, Attribute]
+ hiddenOutput.foreach(
+ attribute =>
+ if (referencedAttributes.containsKey(attribute.exprId) &&
+ !attributesById.containsKey(attribute.exprId) &&
+ !distinctMissingAttributes.containsKey(attribute.exprId)) {
+ distinctMissingAttributes.put(attribute.exprId, attribute)
+ }
+ )
+ distinctMissingAttributes.asScala.values.toSeq
+ }
+
+ /**
+ * Add a top level alias to the map so it can be used when resolving a grouping expression.
+ */
+ def addTopAggregateExpression(aliasedAggregateExpression: Alias): Unit = {
+ topAggregateExpressionsByAliasName.put(
+ aliasedAggregateExpression.name,
+ aliasedAggregateExpression
+ )
}
/**
@@ -174,7 +302,7 @@ class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper {
def expandStar(unresolvedStar: UnresolvedStar): Seq[NamedExpression] = {
unresolvedStar.expandStar(
childOperatorOutput = output,
- childOperatorMetadataOutput = Seq.empty,
+ childOperatorMetadataOutput = hiddenOutput,
resolve =
(nameParts, nameComparator) => attributesForResolution.resolve(nameParts, nameComparator),
suggestedAttributes = output,
@@ -258,6 +386,23 @@ class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper {
* SELECT COL1 FROM t;
* }}}
*
+ * Name resolution can be done using the hidden output for certain operators (e.g. [[Sort]],
+ * [[Filter]]). This is indicated by `canResolveNameByHiddenOutput` which is passed from
+ * [[ExpressionResolver.resolveAttribute]] based on the parent operator.
+ * Example:
+ *
+ * {{{
+ * -- Project's output = [`col1`]; Project's hidden output = [`col1`, `col2`]
+ * SELECT col1 FROM VALUES(1, 2) ORDER BY col2;
+ * }}}
+ *
+ * The names in [[Aggregate.groupingExpressions]] can reference
+ * [[Aggregate.aggregateExpressions]] aliases. `canReferenceAggregateExpressionAliases` will be
+ * true when we are resolving the grouping expressions.
+ * Example:
+ *
+ * {{{ SELECT col1 + col2 AS a FROM VALUES (1, 2) GROUP BY a; }}}
+ *
* We are relying on the [[AttributeSeq]] to perform that work, since it requires complex
* resolution logic involving nested field extraction and multipart name matching.
*
@@ -265,36 +410,57 @@ class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper {
*/
def resolveMultipartName(
multipartName: Seq[String],
- canLaterallyReferenceColumn: Boolean = true): NameTarget = {
- val (candidates, nestedFields) =
- attributesForResolution.getCandidatesForResolution(multipartName, nameComparator)
+ canLaterallyReferenceColumn: Boolean = true,
+ canReferenceAggregateExpressionAliases: Boolean = false,
+ canResolveNameByHiddenOutput: Boolean = false): NameTarget = {
- val (candidatesWithLCAs: Seq[Attribute], referencedAttribute: Option[Attribute]) =
- if (candidates.isEmpty && canLaterallyReferenceColumn) {
- getLcaCandidates(multipartName)
- } else {
- (candidates, None)
- }
-
- val resolvedCandidates = attributesForResolution.resolveCandidates(
- multipartName,
- nameComparator,
- candidatesWithLCAs,
- nestedFields
- )
+ val resolvedMultipartName: ResolvedMultipartName =
+ tryResolveMultipartNameByOutput(
+ multipartName,
+ nameComparator,
+ attributesForResolution,
+ canResolveByProposedAttributes = true
+ ).orElse(
+ tryResolveMultipartNameByOutput(
+ multipartName,
+ nameComparator,
+ metadataAttributesForResolution,
+ canResolveByProposedAttributes = true
+ )
+ )
+ .orElse(
+ tryResolveMultipartNameByOutput(
+ multipartName,
+ nameComparator,
+ hiddenAttributesForResolution,
+ canResolveByProposedAttributes = canResolveNameByHiddenOutput
+ )
+ )
+ .orElse(tryResolveMultipartNameAsLiteralFunction(multipartName))
+ .orElse(
+ tryResolveMultipartNameAsLateralColumnReference(
+ multipartName,
+ canLaterallyReferenceColumn
+ )
+ )
+ .orElse(
+ tryResolveAttributeAsGroupByAlias(multipartName, canReferenceAggregateExpressionAliases)
+ )
+ .getOrElse(ResolvedMultipartName(candidates = Seq.empty, referencedAttribute = None))
- resolvedCandidates match {
+ resolvedMultipartName.candidates match {
case Seq(Alias(child, aliasName)) =>
NameTarget(
candidates = Seq(child),
aliasName = Some(aliasName),
- lateralAttributeReference = referencedAttribute,
+ aliasMetadata = resolvedMultipartName.aliasMetadata,
+ lateralAttributeReference = resolvedMultipartName.referencedAttribute,
output = output
)
case other =>
NameTarget(
candidates = other,
- lateralAttributeReference = referencedAttribute,
+ lateralAttributeReference = resolvedMultipartName.referencedAttribute,
output = output
)
}
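
The orElse chain above encodes the resolution precedence: main output, then qualified-access-only and hidden output (when allowed), then parameterless functions, then lateral column aliases, then GROUP BY references to aggregate aliases. A simplified illustration with plain maps (the categories and values are illustrative, not the resolver's types):

object ResolutionPrecedenceSketch {
  def resolve(
      name: String,
      output: Map[String, String],
      hiddenOutput: Map[String, String],
      literalFunctions: Map[String, String],
      lateralAliases: Map[String, String],
      groupByAliases: Map[String, String]): Option[String] = {
    output.get(name)                       // 1. main output
      .orElse(hiddenOutput.get(name))      // 2. hidden output, when allowed by the operator
      .orElse(literalFunctions.get(name))  // 3. parameterless functions, e.g. current_date
      .orElse(lateralAliases.get(name))    // 4. lateral column aliases
      .orElse(groupByAliases.get(name))    // 5. aggregate expression aliases in GROUP BY
  }

  def main(args: Array[String]): Unit = {
    val resolved = resolve(
      name = "current_date",
      output = Map("current_date" -> "table column"),
      hiddenOutput = Map.empty,
      literalFunctions = Map("current_date" -> "literal function"),
      lateralAliases = Map.empty,
      groupByAliases = Map.empty
    )
    println(resolved) // Some(table column): column resolution wins over the literal function
  }
}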
@@ -323,11 +489,99 @@ class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper {
}
/**
- * Check if `output` contains attributes with `expressionId`. This is used to disable missing
- * attribute propagation for DataFrames, because we don't support it yet.
+ * Returns attribute with `expressionId` if `output` contains it. This is used to preserve
+ * nullability for resolved [[AttributeReference]].
*/
- def hasAttributeWithId(expressionId: ExprId): Boolean = {
- attributeIds.contains(expressionId)
+ def getAttributeById(expressionId: ExprId): Option[Attribute] =
+ Option(attributesById.get(expressionId))
+
+ /**
+ * Returns attribute with `expressionId` if `hiddenOutput` contains it.
+ */
+ def getHiddenAttributeById(expressionId: ExprId): Option[Attribute] =
+ Option(hiddenAttributesById.get(expressionId))
+
+ private def tryResolveMultipartNameByOutput(
+ multipartName: Seq[String],
+ nameComparator: NameComparator,
+ attributesForResolution: AttributeSeq,
+ canResolveByProposedAttributes: Boolean): Option[ResolvedMultipartName] = {
+ if (canResolveByProposedAttributes) {
+ val (candidates, nestedFields) =
+ attributesForResolution.getCandidatesForResolution(multipartName, nameComparator)
+ val resolvedCandidates = attributesForResolution.resolveCandidates(
+ multipartName,
+ nameComparator,
+ candidates,
+ nestedFields
+ )
+ if (resolvedCandidates.nonEmpty) {
+ Some(ResolvedMultipartName(candidates = resolvedCandidates, referencedAttribute = None))
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ }
+
+ private def tryResolveMultipartNameAsLiteralFunction(
+ multipartName: Seq[String]): Option[ResolvedMultipartName] = {
+ val literalFunction = LiteralFunctionResolution.resolve(multipartName).toSeq
+ if (literalFunction.nonEmpty) {
+ Some(ResolvedMultipartName(candidates = literalFunction, referencedAttribute = None))
+ } else {
+ None
+ }
+ }
+
+ private def tryResolveMultipartNameAsLateralColumnReference(
+ multipartName: Seq[String],
+ canLaterallyReferenceColumn: Boolean): Option[ResolvedMultipartName] = {
+ val (candidatesForLca, nestedFields, referencedAttribute) =
+ if (canLaterallyReferenceColumn) {
+ getLcaCandidates(multipartName)
+ } else {
+ (Seq.empty, Seq.empty, None)
+ }
+
+ val resolvedCandidatesForLca = attributesForResolution.resolveCandidates(
+ multipartName,
+ nameComparator,
+ candidatesForLca,
+ nestedFields
+ )
+ if (resolvedCandidatesForLca.nonEmpty) {
+ Some(
+ ResolvedMultipartName(
+ candidates = resolvedCandidatesForLca,
+ referencedAttribute = referencedAttribute
+ )
+ )
+ } else {
+ None
+ }
+ }
+
+ private def tryResolveAttributeAsGroupByAlias(
+ multipartName: Seq[String],
+ canReferenceAggregateExpressionAliases: Boolean): Option[ResolvedMultipartName] = {
+ if (canReferenceAggregateExpressionAliases) {
+ topAggregateExpressionsByAliasName.get(multipartName.head) match {
+ case None =>
+ None
+ case Some(alias) =>
+ Some(
+ ResolvedMultipartName(
+ candidates = Seq(alias.child),
+ referencedAttribute = None,
+ aliasMetadata = Some(alias.metadata)
+ )
+ )
+ }
+ } else {
+ None
+ }
}
/**
@@ -335,15 +589,16 @@ class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper {
* columns. Here we do [[AttributeSeq.fromNormalOutput]] because a struct field can also be
* laterally referenced and we need to properly resolve [[GetStructField]] node.
*/
- private def getLcaCandidates(multipartName: Seq[String]): (Seq[Attribute], Option[Attribute]) = {
+ private def getLcaCandidates(
+ multipartName: Seq[String]): (Seq[Attribute], Seq[String], Option[Attribute]) = {
val referencedAttribute = lcaRegistry.getAttribute(multipartName.head)
if (referencedAttribute.isDefined) {
val attributesForResolution = AttributeSeq.fromNormalOutput(Seq(referencedAttribute.get))
- val (newCandidates, _) =
+ val (newCandidates, nestedFields) =
attributesForResolution.getCandidatesForResolution(multipartName, nameComparator)
- (newCandidates, Some(referencedAttribute.get))
+ (newCandidates, nestedFields, Some(referencedAttribute.get))
} else {
- (Seq.empty, None)
+ (Seq.empty, Seq.empty, None)
}
}
@@ -365,10 +620,10 @@ class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper {
result
}
- private def createAttributeIds(attributes: Seq[Attribute]): HashSet[ExprId] = {
- val result = new HashSet[ExprId]
+ private def createAttributeIds(attributes: Seq[Attribute]): HashMap[ExprId, Attribute] = {
+ val result = new HashMap[ExprId, Attribute](attributes.size)
for (attribute <- attributes) {
- result.add(attribute.exprId)
+ result.put(attribute.exprId, attribute)
}
result
@@ -376,35 +631,40 @@ class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper {
}
/**
- * The [[NameScopeStack]] is a stack of [[NameScope]]s managed by the [[Resolver]]. Usually a top
- * scope is used for name resolution, but in case of correlated subqueries we can lookup names in
- * the parent scopes. Low-level scope creation is managed internally, and only high-level api like
- * [[withNewScope]] is available to the resolvers. Freshly-created [[NameScopeStack]] contains an
- * empty root [[NameScope]], which in the context of [[Resolver]] corresponds to the query output.
+ * The [[NameScopeStack]] is a stack of [[NameScope]]s managed by the [[Resolver]]. Usually the
+ * current scope is used for name resolution, but in case of correlated subqueries we can lookup
+ * names in the parent scopes. Low-level scope creation is managed internally, and only high-level
+ * api like [[withNewScope]] is available to the resolvers. Freshly-created [[NameScopeStack]]
+ * contains an empty root [[NameScope]], which in the context of [[Resolver]] corresponds to the
+ * query output.
*/
class NameScopeStack extends SQLConfHelper {
private val stack = new ArrayDeque[NameScope]
stack.push(new NameScope)
/**
- * Get the top scope, which is a default choice for name resolution.
+ * Get the current scope, which is a default choice for name resolution.
*/
- def top: NameScope = {
+ def current: NameScope = {
stack.peek()
}
/**
- * Completely overwrite the top scope state with operator `output`.
+ * Completely overwrite the current scope state with operator `output` and `hiddenOutput`. If
+ * `hiddenOutput` is not provided, preserve the previous `hiddenOutput`. Additionally, update
+ * nullabilities of attributes in hidden output from new output, so that if attribute was
+ * nullable in either old hidden output or new output, it must stay nullable in new hidden
+ * output as well.
*
* This method is called by the [[Resolver]] when we've calculated the output of an operator that
* is being resolved. The new output is calculated based on the outputs of operator's children.
*
- * Example for [[SubqueryAlias]], here we rewrite the top [[NameScope]]'s attributes to prepend
- * subquery qualifier to their names:
+ * Example for [[SubqueryAlias]], here we rewrite the current [[NameScope]]'s attributes to
+ * prepend subquery qualifier to their names:
*
* {{{
* val qualifier = sa.identifier.qualifier :+ sa.alias
- * scope.overwriteTop(scope.output.map(attribute => attribute.withQualifier(qualifier)))
+ * scopes.overwriteCurrent(scope.output.map(attribute => attribute.withQualifier(qualifier)))
* }}}
*
* Trivially, we would call this method for every operator in the query plan,
@@ -413,10 +673,57 @@ class NameScopeStack extends SQLConfHelper {
*
* This method should be preferred over [[withNewScope]].
*/
- def overwriteTop(output: Seq[Attribute]): Unit = {
- val newScope = new NameScope(output)
+ def overwriteCurrent(
+ output: Option[Seq[Attribute]] = None,
+ hiddenOutput: Option[Seq[Attribute]] = None): Unit = {
+ val hiddenOutputWithUpdatedNullabilities = updateNullabilitiesInHiddenOutput(
+ output.getOrElse(stack.peek().output),
+ hiddenOutput.getOrElse(stack.peek().hiddenOutput)
+ )
+ val newScope = stack.pop.overwriteOutput(output, Some(hiddenOutputWithUpdatedNullabilities))
+
+ stack.push(newScope)
+ }
+
+ /**
+ * Overwrites output of the current [[NameScope]] entry and:
+ * 1. extends hidden output with the provided output (only attributes that are not in the hidden
+ * output are added). This is done because resolution of arguments can be done through certain
+ * operators using the hidden output. This use case is specific to DataFrame programs. Example:
+ *
+ * {{{
+ * val df = (1 to 100).map { i => (i, i % 10, i % 2 == 0) }.toDF("a", "b", "c")
+ * df.select($"a", $"b").filter($"c")
+ * }}}
+ *
+ * Unresolved tree would be:
+ *
+ * Filter 'c
+ * +- 'Project ['a, 'b]
+ * +- Project [_1 AS a, _2 AS b, _3 AS c]
+ * +- LocalRelation [_1, _2, _3]
+ *
+ * As can be seen in the example above, `c` from the [[Filter]] condition should be resolved
+ * using the `hiddenOutput` (because its child's output doesn't contain `c`). That's why the new
+ * hidden output has to contain both the previous scope's hidden output and the provided output.
+ * This is done for [[Project]] and [[Aggregate]] operators.
+ *
+ * 2. updates nullabilities of attributes in hidden output from new output, so that if attribute
+ * was nullable in either old hidden output or new output, it must stay nullable in new hidden
+ * output as well.
+ */
+ def overwriteOutputAndExtendHiddenOutput(output: Seq[Attribute]): Unit = {
+ val prevScope = stack.pop
+ val hiddenOutputWithUpdatedNullabilities =
+ updateNullabilitiesInHiddenOutput(output, prevScope.hiddenOutput)
+ val hiddenOutput = hiddenOutputWithUpdatedNullabilities ++ output.filter { attribute =>
+ prevScope.getHiddenAttributeById(attribute.exprId).isEmpty
+ }
+ val newScope = prevScope.overwriteOutput(
+ output = Some(output),
+ hiddenOutput = Some(hiddenOutput)
+ )
- stack.pop()
stack.push(newScope)
}
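
A rough sketch of the extension rule described above, with a tiny illustrative Attr type instead of Catalyst attributes: output attributes are appended to the hidden output only if their expression id is not already present, which is what lets a DataFrame Filter above a Project still resolve a dropped column:

object ExtendHiddenOutputSketch {
  case class Attr(exprId: Long, name: String)

  def extendHiddenOutput(output: Seq[Attr], hiddenOutput: Seq[Attr]): Seq[Attr] = {
    val knownIds = hiddenOutput.map(_.exprId).toSet
    hiddenOutput ++ output.filterNot(attr => knownIds.contains(attr.exprId))
  }

  def main(args: Array[String]): Unit = {
    val hidden = Seq(Attr(1, "a"), Attr(2, "b"), Attr(3, "c"))
    val projected = Seq(Attr(1, "a"), Attr(2, "b"))
    // `c` stays available through the hidden output for df.select($"a", $"b").filter($"c").
    println(extendHiddenOutput(projected, hidden)) // List(Attr(1,a), Attr(2,b), Attr(3,c))
  }
}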
@@ -455,13 +762,176 @@ class NameScopeStack extends SQLConfHelper {
* resolve(unresolvedExcept.right)
* }
* }}}
+ *
+ * After the body finishes executing within `withNewScope`, the stack is popped. The popped
+ * scope's `hiddenOutput` is also propagated upwards (for name resolution) by overwriting the
+ * current [[NameScope.hiddenOutput]] with it. This is not done when `withNewScope` was called in
+ * the context of subquery resolution (which is indicated by the `isSubqueryRoot` flag), because
+ * we don't want to overwrite the existing `hiddenOutput` of the main plan.
+ *
+ * @param isSubqueryRoot Indicates that the current scope is a root of a subquery. This is used by
+ * [[NameScopeStack.resolveMultipartName]] to detect the nearest outer scope.
*/
- def withNewScope[R](body: => R): R = {
- stack.push(new NameScope)
+ def withNewScope[R](isSubqueryRoot: Boolean = false)(body: => R): R = {
+ stack.push(new NameScope(isSubqueryRoot = isSubqueryRoot))
try {
body
} finally {
- stack.pop()
+ val childScope = stack.pop()
+ if (stack.size() > 0 && !childScope.isSubqueryRoot) {
+ val currentScope = stack.pop()
+ stack.push(currentScope.overwriteOutput(hiddenOutput = Some(childScope.hiddenOutput)))
+ }
+ }
+ }
+
+ /**
+ * Resolve a multipart name into a [[NameTarget]] from the current or outer scopes. Currently we
+ * only support one level of correlation, so we look up `multipartName` in the current scope,
+ * and if the name was not found, we look it up in the nearest outer scope:
+ *
+ * {{{
+ * -- 'a' is a simple lookup from the current scope.
+ * SELECT a FROM (SELECT col1 AS a FROM VALUES (1));
+ * }}}
+ *
+ * {{{
+ * -- `a` in `(SELECT a + 1)` will be wrapped in [[OuterReference]].
+ * SELECT a, (SELECT a + 1) AS b FROM (SELECT col1 AS a FROM VALUES (1));
+ * }}}
+ *
+ * The ambiguity between local and outer references is resolved in favour of the current scope:
+ * {{{
+ * -- There's no correlation here, subquery references its column from the current scope.
+ * -- This returns [1, 2].
+ * SELECT col1, (SELECT col1 FROM VALUES (2)) AS b FROM VALUES (1)
+ * }}}
+ *
+ * Correlations beyond one level are not supported:
+ * {{{
+ * -- 3 levels, fails with `UNRESOLVED_COLUMN`.
+ * SELECT (
+ * SELECT (
+ * SELECT t1.col1 FROM VALUES (3) AS t3
+ * ) FROM VALUES (2) AS t2
+ * ) FROM VALUES (1) AS t1;
+ * }}}
+ *
+ * Correlated references are accessible from lower subquery operators:
+ * {{{
+ * -- Returns [1, 1]
+ * SELECT
+ * col1, (SELECT * FROM (SELECT t1.col1 FROM VALUES (2) AS t2))
+ * FROM
+ * VALUES (1) AS t1;
+ * }}}
+ *
+ * We cannot reference LCA or aggregate expression by alias in the outer scope:
+ * {{{
+ * -- These examples fail with `UNRESOLVED_COLUMN`.
+ * -- LCA in outer scope.
+ * SELECT col1 AS a, (SELECT a + 1) AS b FROM VALUES (1);
+ * -- Aliased aggregate expression in the outer scope.
+ * SELECT col1 AS a FROM VALUES (1) GROUP BY a, (SELECT a + 1);
+ * }}}
+ *
+ * Only [[Attribute]]s are wrapped in [[OuterReference]]:
+ * {{{
+ * -- The subquery's [[Project]] list will contain outer(col1#0).f1.f2.
+ * SELECT
+ * col1, (SELECT col1.f1.f2 + 1) AS b
+ * FROM
+ * VALUES (named_struct('f1', named_struct('f2', 1)));
+ * }}}
+ */
+ def resolveMultipartName(
+ multipartName: Seq[String],
+ canLaterallyReferenceColumn: Boolean = true,
+ canReferenceAggregateExpressionAliases: Boolean = false,
+ canResolveNameByHiddenOutput: Boolean = false): NameTarget = {
+ val nameTargetFromCurrentScope = current.resolveMultipartName(
+ multipartName,
+ canLaterallyReferenceColumn = canLaterallyReferenceColumn,
+ canReferenceAggregateExpressionAliases = canReferenceAggregateExpressionAliases,
+ canResolveNameByHiddenOutput = canResolveNameByHiddenOutput
+ )
+
+ if (nameTargetFromCurrentScope.candidates.nonEmpty) {
+ nameTargetFromCurrentScope
+ } else {
+ outer match {
+ case Some(outer) =>
+ val nameTarget = outer.resolveMultipartName(
+ multipartName,
+ canLaterallyReferenceColumn = false,
+ canReferenceAggregateExpressionAliases = false
+ )
+
+ if (nameTarget.candidates.nonEmpty) {
+ nameTarget.copy(
+ isOuterReference = true,
+ candidates = nameTarget.candidates.map(wrapCandidateInOuterReference)
+ )
+ } else {
+ nameTargetFromCurrentScope
+ }
+
+ case None =>
+ nameTargetFromCurrentScope
+ }
+ }
+ }
+
+ /**
+ * Find the nearest outer scope and return it if we are in a subquery.
+ */
+ private def outer: Option[NameScope] = {
+ var outerScope: Option[NameScope] = None
+
+ val iter = stack.iterator
+ while (iter.hasNext && outerScope.isEmpty) {
+ val scope = iter.next
+
+ if (scope.isSubqueryRoot && iter.hasNext) {
+ outerScope = Some(iter.next)
+ }
+ }
+
+ outerScope
+ }
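
A standalone sketch of this nearest-outer-scope lookup (the Frame type is an illustrative stand-in for NameScope; the stack is modelled as a List whose head is the top):

object OuterScopeLookupSketch {
  case class Frame(name: String, isSubqueryRoot: Boolean = false)

  // The scope right below the first subquery-root scope is the nearest outer scope.
  def findOuter(stack: List[Frame]): Option[Frame] =
    stack.zip(stack.drop(1)).collectFirst {
      case (frame, below) if frame.isSubqueryRoot => below
    }

  def main(args: Array[String]): Unit = {
    val stack = List(
      Frame("subqueryProject"),
      Frame("subqueryRoot", isSubqueryRoot = true),
      Frame("mainPlan")
    )
    println(findOuter(stack)) // Some(Frame(mainPlan,false))
  }
}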
+
+ /**
+ * Wrap candidate in [[OuterReference]]. If the root is not an [[Attribute]], but an
+ * [[ExtractValue]] (struct/map/array field reference) we find the actual [[Attribute]] and wrap
+ * it in [[OuterReference]].
+ */
+ private def wrapCandidateInOuterReference(candidate: Expression): Expression = candidate match {
+ case candidate: Attribute =>
+ OuterReference(candidate)
+ case extractValue: ExtractValue =>
+ extractValue.transformUp {
+ case attribute: Attribute => OuterReference(attribute)
+ case other => other
+ }
+ case _ =>
+ candidate
+ }
+
+ /**
+ * When the scope gets a new output, we need to refresh the nullabilities in its `hiddenOutput`.
+ * If an attribute is nullable in either the old hidden output or the new output, it must remain
+ * nullable in the new hidden output as well.
+ */
+ private def updateNullabilitiesInHiddenOutput(
+ output: Seq[Attribute],
+ hiddenOutput: Seq[Attribute]) = {
+ val outputLookup = new HashMap[ExprId, Attribute](output.size)
+ output.foreach(attribute => outputLookup.put(attribute.exprId, attribute))
+
+ hiddenOutput.map {
+ case attribute if outputLookup.containsKey(attribute.exprId) =>
+ attribute.withNullability(attribute.nullable || outputLookup.get(attribute.exprId).nullable)
+ case attribute => attribute
}
}
}
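
To illustrate the nullability merge above, a hedged sketch on plain (exprId, name, nullable) tuples instead of Catalyst attributes:

object NullabilityMergeSketch {
  type Attr = (Long, String, Boolean) // (exprId, name, nullable)

  def mergeNullability(output: Seq[Attr], hiddenOutput: Seq[Attr]): Seq[Attr] = {
    val outputNullableById = output.map { case (id, _, nullable) => id -> nullable }.toMap
    hiddenOutput.map {
      case (id, name, nullable) =>
        // Nullable in either the old hidden output or the new output => stays nullable.
        (id, name, nullable || outputNullableById.getOrElse(id, false))
    }
  }

  def main(args: Array[String]): Unit = {
    val hidden = Seq((1L, "col1", false), (2L, "col2", true))
    val output = Seq((1L, "col1", true))
    println(mergeNullability(output, hidden)) // List((1,col1,true), (2,col2,true))
  }
}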
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameTarget.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameTarget.scala
index 9e949835c4137..5b5f4e444bb9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameTarget.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameTarget.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.util.StringUtils.orderSuggestedIdentifiersBySimilarity
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.Metadata
/**
* [[NameTarget]] is a result of a multipart name resolution of the
@@ -54,16 +55,22 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
* @param aliasName If the candidates size is 1 and it's type is [[ExtractValue]] (which means that
* it's a field/value/item from a recursive type), then the `aliasName` should be the name with
* which the candidate needs to be aliased. Otherwise, `aliasName` is `None`.
+ * @param aliasMetadata If the candidates were created out of expressions referenced by group by
+ * alias, store the metadata of the alias. Otherwise, `aliasMetadata` is `None`.
* @param lateralAttributeReference If the candidate is laterally referencing another column this
* field is populated with that column's attribute.
* @param output [[output]] of a [[NameScope]] that produced this [[NameTarget]]. Used to provide
* suggestions for thrown errors.
+ * @param isOuterReference A flag indicating that this [[NameTarget]] resolves to an outer
+ * reference.
*/
case class NameTarget(
candidates: Seq[Expression],
aliasName: Option[String] = None,
+ aliasMetadata: Option[Metadata] = None,
lateralAttributeReference: Option[Attribute] = None,
- output: Seq[Attribute] = Seq.empty) {
+ output: Seq[Attribute] = Seq.empty,
+ isOuterReference: Boolean = false) {
/**
* Pick a single candidate from `candidates`:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanLogger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanLogger.scala
index a0d67893484cf..f9bdfa06ae54b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanLogger.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanLogger.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.analysis.resolver
-import org.apache.spark.internal.{Logging, MDC, MessageWithContext}
+import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{MESSAGE, QUERY_PLAN}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -32,65 +32,44 @@ class PlanLogger extends Logging {
private val expressionTreeChangeLogLevel = SQLConf.get.expressionTreeChangeLogLevel
def logPlanResolutionEvent(plan: LogicalPlan, event: String): Unit = {
- log(() => log"""
- |=== Plan resolution: ${MDC(MESSAGE, event)} ===
- |${MDC(QUERY_PLAN, plan.treeString)}
- """.stripMargin, planChangeLogLevel)
+ logBasedOnLevel(planChangeLogLevel) {
+ log"""
+ |=== Plan resolution: ${MDC(MESSAGE, event)} ===
+ |${MDC(QUERY_PLAN, plan.treeString)}
+ """.stripMargin
+ }
}
def logPlanResolution(unresolvedPlan: LogicalPlan, resolvedPlan: LogicalPlan): Unit = {
- log(
- () =>
- log"""
- |=== Unresolved plan -> Resolved plan ===
- |${MDC(
- QUERY_PLAN,
- sideBySide(
- unresolvedPlan.treeString,
- resolvedPlan.treeString
- ).mkString("\n")
- )}
- """.stripMargin,
- planChangeLogLevel
- )
+ logBasedOnLevel(planChangeLogLevel) {
+ val unresolved = unresolvedPlan.treeString
+ val resolved = resolvedPlan.treeString
+ log"""
+ |=== Unresolved plan -> Resolved plan ===
+ |${MDC(QUERY_PLAN, sideBySide(unresolved, resolved).mkString("\n"))}
+ """.stripMargin
+ }
}
def logExpressionTreeResolutionEvent(expressionTree: Expression, event: String): Unit = {
- log(
- () => log"""
- |=== Expression tree resolution: ${MDC(MESSAGE, event)} ===
- |${MDC(QUERY_PLAN, expressionTree.treeString)}
- """.stripMargin,
- expressionTreeChangeLogLevel
- )
+ logBasedOnLevel(expressionTreeChangeLogLevel) {
+ log"""
+ |=== Expression tree resolution: ${MDC(MESSAGE, event)} ===
+ |${MDC(QUERY_PLAN, expressionTree.treeString)}
+ """.stripMargin
+ }
}
def logExpressionTreeResolution(
unresolvedExpressionTree: Expression,
resolvedExpressionTree: Expression): Unit = {
- log(
- () =>
- log"""
- |=== Unresolved expression tree -> Resolved expression tree ===
- |${MDC(
- QUERY_PLAN,
- sideBySide(
- unresolvedExpressionTree.treeString,
- resolvedExpressionTree.treeString
- ).mkString("\n")
- )}
- """.stripMargin,
- expressionTreeChangeLogLevel
- )
- }
-
- private def log(createMessage: () => MessageWithContext, logLevel: String): Unit =
- logLevel match {
- case "TRACE" => logTrace(createMessage().message)
- case "DEBUG" => logDebug(createMessage().message)
- case "INFO" => logInfo(createMessage())
- case "WARN" => logWarning(createMessage())
- case "ERROR" => logError(createMessage())
- case _ => logTrace(createMessage().message)
+ logBasedOnLevel(expressionTreeChangeLogLevel) {
+ val unresolved = unresolvedExpressionTree.treeString
+ val resolved = resolvedExpressionTree.treeString
+ log"""
+ |=== Unresolved expression tree -> Resolved expression tree ===
+ |${MDC(QUERY_PLAN, sideBySide(unresolved, resolved).mkString("\n"))}
+ """.stripMargin
}
+ }
}
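
The refactor above funnels all messages through a single level-dispatching helper and builds them lazily. A framework-free sketch of the same idea (the helper name and enabled-level set here are illustrative, not Spark's Logging API):

object LazyLevelLoggingSketch {
  // The message is by-name, so expensive rendering only happens for enabled levels.
  def logBasedOnLevel(level: String)(message: => String): Unit = {
    val enabledLevels = Set("INFO", "WARN", "ERROR") // assume DEBUG/TRACE are disabled
    if (enabledLevels.contains(level)) {
      println(s"[$level] $message")
    }
  }

  def main(args: Array[String]): Unit = {
    logBasedOnLevel("DEBUG") {
      sys.error("expensive plan rendering should not run") // never evaluated here
    }
    logBasedOnLevel("INFO")("plan changed")
  }
}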
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanRewriter.scala
new file mode 100644
index 0000000000000..8755c24f9ca31
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanRewriter.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
+import org.apache.spark.sql.catalyst.plans.logical.{AnalysisHelper, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
+
+/**
+ * Utility wrapper on top of [[RuleExecutor]], used to apply post-resolution rules on the
+ * single-pass resolution result. [[PlanRewriter]] transforms the plan and the subqueries inside.
+ */
+class PlanRewriter(planRewriteRules: Seq[Rule[LogicalPlan]]) {
+ private val planRewriter = new RuleExecutor[LogicalPlan] {
+ override def batches: Seq[Batch] =
+ Seq(
+ Batch(
+ "Plan Rewriting",
+ Once,
+ planRewriteRules: _*
+ )
+ )
+ }
+
+ /**
+ * Rewrites the plan by first recursing into all subqueries and applying post-resolution rules on
+ * them and then applying post-resolution rules on the entire plan.
+ */
+ def rewriteWithSubqueries(plan: LogicalPlan): LogicalPlan =
+ AnalysisHelper.allowInvokingTransformsInAnalyzer {
+ val planWithRewrittenSubqueries =
+ plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
+ case subqueryExpression: SubqueryExpression =>
+ val rewrittenSubqueryPlan = rewrite(subqueryExpression.plan)
+
+ subqueryExpression.withNewPlan(rewrittenSubqueryPlan)
+ }
+
+ rewrite(planWithRewrittenSubqueries)
+ }
+
+ /**
+ * Rewrites the plan __without__ recursing into the subqueries.
+ */
+ private def rewrite(plan: LogicalPlan): LogicalPlan = planRewriter.execute(plan)
+}
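
A simplified, framework-free sketch of this rewrite order (the Plan and Rule types are illustrative stand-ins for LogicalPlan and Rule[LogicalPlan]): subqueries are rewritten first, then the outer plan, with the rule batch applied once:

object PlanRewriterSketch {
  case class Plan(name: String, subqueries: Seq[Plan] = Nil)
  type Rule = Plan => Plan

  // Rewrite every subquery plan first, then apply the same rule batch once to the outer plan.
  def rewriteWithSubqueries(plan: Plan, rules: Seq[Rule]): Plan = {
    val withRewrittenSubqueries =
      plan.copy(subqueries = plan.subqueries.map(rewriteWithSubqueries(_, rules)))
    rules.foldLeft(withRewrittenSubqueries)((p, rule) => rule(p))
  }

  def main(args: Array[String]): Unit = {
    val markRewritten: Rule = p => p.copy(name = p.name + " [rewritten]")
    val plan = Plan("outer", Seq(Plan("subquery")))
    println(rewriteWithSubqueries(plan, Seq(markRewritten)))
    // Plan(outer [rewritten],List(Plan(subquery [rewritten],List())))
  }
}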
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PredicateResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PredicateResolver.scala
index 1c4d8dd50113b..7669fcce53295 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PredicateResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PredicateResolver.scala
@@ -43,17 +43,17 @@ class PredicateResolver(
with ResolvesExpressionChildren {
private val typeCoercionTransformations: Seq[Expression => Expression] =
- if (conf.ansiEnabled) {
- PredicateResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS
- } else {
- PredicateResolver.TYPE_COERCION_TRANSFORMATIONS
- }
+ if (conf.ansiEnabled) {
+ PredicateResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS
+ } else {
+ PredicateResolver.TYPE_COERCION_TRANSFORMATIONS
+ }
private val typeCoercionResolver: TypeCoercionResolver =
new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionTransformations)
override def resolve(unresolvedPredicate: Predicate): Expression = {
val predicateWithResolvedChildren =
- withResolvedChildren(unresolvedPredicate, expressionResolver.resolve)
+ withResolvedChildren(unresolvedPredicate, expressionResolver.resolve _)
val predicateWithTypeCoercion = typeCoercionResolver.resolve(predicateWithResolvedChildren)
val predicateWithCharTypePadding = {
ApplyCharTypePaddingHelper.singleNodePaddingForStringComparison(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProducesUnresolvedSubtree.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProducesUnresolvedSubtree.scala
index 8d85804a93634..576cb98bbabe8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProducesUnresolvedSubtree.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProducesUnresolvedSubtree.scala
@@ -37,6 +37,10 @@ trait ProducesUnresolvedSubtree extends ResolvesExpressionChildren {
* node. Method ensures that the downwards traversal never visits previously resolved nodes by
* tracking the limits of the traversal with a tag. Invokes a resolver callback to resolve
* children, but DOES NOT resolve the root of the subtree.
+ *
+ * If the result of the callback is the same object as the source `expression`, we don't perform
+ * the downwards traversal. This is both more efficient and a fail-safe mechanism in case we
+ * accidentally lose the [[ExpressionResolver.SINGLE_PASS_SUBTREE_BOUNDARY]] tag.
*/
protected def withResolvedSubtree(
expression: Expression,
@@ -45,8 +49,31 @@ trait ProducesUnresolvedSubtree extends ResolvesExpressionChildren {
child.setTagValue(ExpressionResolver.SINGLE_PASS_SUBTREE_BOUNDARY, ())
}
- val result = body
+ val resultExpression = body
- withResolvedChildren(result, expressionResolver)
+ if (resultExpression.eq(expression)) {
+ expression.children.foreach { child =>
+ child.unsetTagValue(ExpressionResolver.SINGLE_PASS_SUBTREE_BOUNDARY)
+ }
+ resultExpression
+ } else {
+ withResolvedChildren(resultExpression, expressionResolver)
+ }
+ }
+
+ /**
+ * Try to pop the tag that marks the boundary of the single-pass subtree resolution.
+ * [[ExpressionResolver]] calls this method to check if the subtree traversal needs to be stopped
+ * because the lower subtree is already resolved.
+ */
+ protected def tryPopSinglePassSubtreeBoundary(unresolvedExpression: Expression): Boolean = {
+ if (unresolvedExpression
+ .getTagValue(ExpressionResolver.SINGLE_PASS_SUBTREE_BOUNDARY)
+ .isDefined) {
+ unresolvedExpression.unsetTagValue(ExpressionResolver.SINGLE_PASS_SUBTREE_BOUNDARY)
+ true
+ } else {
+ false
+ }
}
}
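
To illustrate the boundary mechanism above, a hedged sketch that uses a plain mutable set of names instead of Catalyst TreeNode tags:

object SubtreeBoundarySketch {
  case class Node(name: String, children: Seq[Node] = Nil)

  // Plain mutable marker set standing in for the SINGLE_PASS_SUBTREE_BOUNDARY tag.
  private val boundaries = scala.collection.mutable.Set.empty[String]

  def tryPopBoundary(node: Node): Boolean = boundaries.remove(node.name)

  def visit(node: Node): Unit = {
    if (tryPopBoundary(node)) return // lower subtree already resolved, stop here
    println(s"resolving ${node.name}")
    node.children.foreach(visit)
  }

  def main(args: Array[String]): Unit = {
    val alreadyResolved = Node("alreadyResolved", Seq(Node("grandchild")))
    val root = Node("root", Seq(alreadyResolved, Node("fresh")))
    boundaries += "alreadyResolved" // simulate the boundary tag on that child
    visit(root)                     // prints: resolving root, resolving fresh
  }
}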
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala
index 04623eca1c545..f6af7e4f2f41f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala
@@ -17,12 +17,21 @@
package org.apache.spark.sql.catalyst.analysis.resolver
+import java.util.{HashMap, HashSet}
+
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
-import org.apache.spark.sql.catalyst.analysis.{withPosition, AnalysisErrorAt}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.expressions.{
+ Alias,
+ Attribute,
+ Expression,
+ ExprId,
+ ExprUtils,
+ NamedExpression
+}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
/**
@@ -34,14 +43,13 @@ import org.apache.spark.sql.internal.SQLConf
*/
class ProjectResolver(operatorResolver: Resolver, expressionResolver: ExpressionResolver)
extends TreeNodeResolver[Project, LogicalPlan] {
-
- private val isLcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)
private val scopes = operatorResolver.getNameScopes
/**
- * * [[Project]] introduces a new scope to resolve its subtree and project list expressions.
- * * During the resolution we determine whether the output operator will be [[Aggregate]] or
- * * [[Project]] (based on the `hasAggregateExpressions` flag).
+ * [[Project]] introduces a new scope to resolve its subtree and project list expressions.
+ * During the resolution we determine whether the output operator will be [[Aggregate]] or
+ * [[Project]] (based on the `hasAggregateExpressions` flag). If the result is an [[Aggregate]],
+ * validate it using [[ExprUtils.assertValidAggregation]].
*
* If the output operator is [[Project]] and if lateral column alias resolution is enabled, we
* construct a multi-level [[Project]], created from all lateral column aliases and their
@@ -51,10 +59,15 @@ class ProjectResolver(operatorResolver: Resolver, expressionResolver: Expression
* current scope with resolved operators output to expose new names to the parent operators.
*/
override def resolve(unresolvedProject: Project): LogicalPlan = {
- val (resolvedOperator, resolvedProjectList) = scopes.withNewScope {
+ val (resolvedOperator, resolvedProjectList) = scopes.withNewScope() {
val resolvedChild = operatorResolver.resolve(unresolvedProject.child)
+ val childReferencedAttributes = expressionResolver.getLastReferencedAttributes
val resolvedProjectList =
expressionResolver.resolveProjectList(unresolvedProject.projectList, unresolvedProject)
+
+ val resolvedChildWithMetadataColumns =
+ retainOriginalJoinOutput(resolvedChild, resolvedProjectList, childReferencedAttributes)
+
if (resolvedProjectList.hasAggregateExpressions) {
if (resolvedProjectList.hasLateralColumnAlias) {
// Disable LCA in Aggregates until fully supported.
@@ -63,32 +76,157 @@ class ProjectResolver(operatorResolver: Resolver, expressionResolver: Expression
val aggregate = Aggregate(
groupingExpressions = Seq.empty[Expression],
aggregateExpressions = resolvedProjectList.expressions,
- child = resolvedChild,
+ child = resolvedChildWithMetadataColumns,
hint = None
)
- if (resolvedProjectList.hasAttributes) {
- aggregate.failAnalysis(errorClass = "MISSING_GROUP_BY", messageParameters = Map.empty)
- }
+
+ // TODO: This validation function does a post-traversal. This is discouraged in
+ // the single-pass Analyzer.
+ ExprUtils.assertValidAggregation(aggregate)
+
(aggregate, resolvedProjectList)
} else {
- val projectWithLca = if (isLcaEnabled) {
- buildProjectWithResolvedLCAs(resolvedChild, resolvedProjectList.expressions)
+ val projectWithLca = if (conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
+ buildProjectWithResolvedLCAs(
+ resolvedChildWithMetadataColumns,
+ resolvedProjectList.expressions
+ )
} else {
- Project(resolvedProjectList.expressions, resolvedChild)
+ Project(resolvedProjectList.expressions, resolvedChildWithMetadataColumns)
}
(projectWithLca, resolvedProjectList)
}
}
- withPosition(unresolvedProject) {
- scopes.overwriteTop(
+ scopes.overwriteOutputAndExtendHiddenOutput(
+ output =
resolvedProjectList.expressions.map(namedExpression => namedExpression.toAttribute)
- )
- }
+ )
resolvedOperator
}
+ /**
+ * This method adds a [[Project]] node on top of a [[Join]] if the [[Join]]'s output has
+ * changed because metadata columns were added to [[Project]] nodes below the [[Join]]. This is
+ * necessary in order to stay compatible with the fixed-point analyzer. Instead of doing this in
+ * [[JoinResolver]], we must do it here, because while resolving the [[Join]] we still don't know
+ * whether we should add a [[Project]] or not. For example, consider the following query:
+ *
+ * {{{
+ * -- tables: nt1(k, v1), nt2(k, v2), nt3(k, v3)
+ * SELECT * FROM nt1 NATURAL JOIN nt2 JOIN nt3 ON nt2.k = nt3.k;
+ * }}}
+ *
+ * Unresolved plan will be:
+ *
+ * 'Project [*]
+ * +- 'Join Inner, ('nt2.k = 'nt3.k)
+ * :- 'Join NaturalJoin(Inner)
+ * : :- 'UnresolvedRelation [nt1], [], false
+ * : +- 'UnresolvedRelation [nt2], [], false
+ * +- 'UnresolvedRelation [nt3], [], false
+ *
+ * After resolving the inner natural join, the plan becomes:
+ *
+ * 'Project [*]
+ * +- 'Join Inner, ('nt2.k = 'nt3.k)
+ * :- Project [k#15, v1#16, v2#28, k#27]
+ * : +- Join Inner, (k#15 = k#27)
+ * : : +- SubqueryAlias nt1
+ * : : +- LocalRelation [k#15, v1#16]
+ * : +- SubqueryAlias nt2
+ * : +- LocalRelation [k#27, v2#28]
+ * +- 'UnresolvedRelation [nt3], [], false
+ *
+ * Because we are resolving a natural join, we have placed a [[Project]] node on top of it with
+ * the inner join's output. Additionally, in single-pass, we add all metadata columns as we
+ * resolve up and then prune away unnecessary columns later (more in [[PruneMetadataColumns]]).
+ * This is necessary in order to stay compatible with fixed-point's [[AddMetadataColumns]] rule,
+ * because [[AddMetadataColumns]] will recognize k#27 as a missing attribute needed for the
+ * [[Join]] condition and will therefore add it to the [[Project]] node below. Because of this we
+ * also add k#27 as a metadata column to this [[Project]]. This addition of a metadata column
+ * changes the original output of the outer join (because one of the inputs has changed) and in
+ * order to stay compatible with fixed-point, we need to place another [[Project]] on top of the
+ * outer join with its original output. Now, the final plan looks like this:
+ *
+ * Project [k#15, v1#16, v2#28, k#31, v3#32]
+ * +- Project [k#15, v1#16, v2#28, k#31, v3#32]
+ * +- Join Inner, (k#27 = k#31)
+ * :- Project [k#15, v1#16, v2#28, k#27]
+ * : +- Join Inner, (k#15 = k#27)
+ * : :- SubqueryAlias nt1
+ * : : +- LocalRelation [k#15, v1#16]
+ * : +- SubqueryAlias nt2
+ * : +- LocalRelation [k#27, v2#28]
+ * +- SubqueryAlias nt3
+ * +- LocalRelation [k#31, v3#32]
+ *
+ * As can be seen, the [[Project]] node immediately on top of [[Join]] doesn't contain the
+ * metadata column k#27 that we have added. Because of this, k#27 will be pruned away later.
+ *
+ * Now consider the following query for the same input:
+ *
+ * {{{ SELECT *, nt2.k FROM nt1 NATURAL JOIN nt2 JOIN nt3 ON nt2.k = nt3.k; }}}
+ *
+ * The plan will be:
+ *
+ * Project [k#15, v1#16, v2#28, k#31, v3#32, k#27]
+ * +- Join Inner, (k#27 = k#31)
+ * :- Project [k#15, v1#16, v2#28, k#27]
+ * : +- Join Inner, (k#15 = k#27)
+ * : :- SubqueryAlias nt1
+ * : : +- LocalRelation [k#15, v1#16]
+ * : +- SubqueryAlias nt2
+ * : +- LocalRelation [k#27, v2#28]
+ * +- SubqueryAlias nt3
+ * +- LocalRelation [k#31, v3#32]
+ *
+ * In fixed-point, because we are referencing k#27 from the [[Project]] node,
+ * [[AddMetadataColumns]] (which transforms the tree top-down) will see that the [[Project]] has
+ * a missing metadata column and will therefore place k#27 in the [[Project]] node below the
+ * outer [[Join]]. This is important, because by [[AddMetadataColumns]]'s logic we don't check
+ * whether the output of the outer [[Join]] has changed; we only check the output change for the
+ * top-most [[Project]]. Because we need to stay fully compatible with fixed-point, in this case
+ * we don't place a [[Project]] on top of the outer [[Join]] even though its output has changed.
+ */
+ private def retainOriginalJoinOutput(
+ plan: LogicalPlan,
+ resolvedProjectList: ResolvedProjectList,
+ childReferencedAttributes: HashMap[ExprId, Attribute]): LogicalPlan = {
+ plan match {
+ case join: Join
+ if childHasMissingAttributesNotInProjectList(
+ resolvedProjectList.expressions,
+ childReferencedAttributes
+ ) =>
+ Project(scopes.current.output, join)
+ case other => other
+ }
+ }
+
+ /**
+ * Returns true if a child node of [[Project]] has missing attributes that can be resolved from
+ * [[NameScope.hiddenOutput]] and those attributes are not present in the project list.
+ */
+ private def childHasMissingAttributesNotInProjectList(
+ projectList: Seq[NamedExpression],
+ referencedAttributes: HashMap[ExprId, Attribute]): Boolean = {
+ val expressionIdsFromProjectList = new HashSet[ExprId](projectList.map(_.exprId).asJava)
+ val missingAttributes = new HashMap[ExprId, Attribute]
+ referencedAttributes.asScala
+ .foreach {
+ case (exprId, attribute) =>
+ if (!expressionIdsFromProjectList.contains(exprId) && attribute.isMetadataCol) {
+ missingAttributes.put(exprId, attribute)
+ }
+ }
+ val missingAttributeResolvedByHiddenOutput =
+ scopes.current.resolveMissingAttributesByHiddenOutput(missingAttributes)
+
+ missingAttributeResolvedByHiddenOutput.nonEmpty
+ }
+
/**
* Builds a multi-level [[Project]] with all lateral column aliases and their dependencies. First,
* from top scope, we acquire dependency levels of all aliases. Dependency level is defined as a
@@ -127,15 +265,15 @@ class ProjectResolver(operatorResolver: Resolver, expressionResolver: Expression
private def buildProjectWithResolvedLCAs(
resolvedChild: LogicalPlan,
originalProjectList: Seq[NamedExpression]) = {
- val aliasDependencyMap = scopes.top.lcaRegistry.getAliasDependencyLevels()
+ val aliasDependencyMap = scopes.current.lcaRegistry.getAliasDependencyLevels()
val (finalChildPlan, _) = aliasDependencyMap.asScala.foldLeft(
- (resolvedChild, scopes.top.output.map(_.asInstanceOf[NamedExpression]))
+ (resolvedChild, scopes.current.output.map(_.asInstanceOf[NamedExpression]))
) {
case ((currentPlan, currentProjectList), availableAliases) =>
val referencedAliases = new ArrayBuffer[Alias]
availableAliases.forEach(
alias =>
- if (scopes.top.lcaRegistry.isAttributeLaterallyReferenced(alias.toAttribute)) {
+ if (scopes.current.lcaRegistry.isAttributeLaterallyReferenced(alias.toAttribute)) {
referencedAliases.append(alias)
}
)
@@ -150,7 +288,7 @@ class ProjectResolver(operatorResolver: Resolver, expressionResolver: Expression
val finalProjectList = originalProjectList.map(
alias =>
- if (scopes.top.lcaRegistry.isAttributeLaterallyReferenced(alias.toAttribute)) {
+ if (scopes.current.lcaRegistry.isAttributeLaterallyReferenced(alias.toAttribute)) {
alias.toAttribute
} else {
alias
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PruneMetadataColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PruneMetadataColumns.scala
new file mode 100644
index 0000000000000..2bbf627da8fcf
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PruneMetadataColumns.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import java.util.HashSet
+
+import org.apache.spark.sql.catalyst.expressions.{ExprId, NamedExpression, OuterReference}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+import org.apache.spark.sql.catalyst.util._
+
+/**
+ * This is a special rule for the single-pass resolver that performs a single downward traversal
+ * in order to prune away unnecessary metadata columns. This is necessary because the fixed-point
+ * analyzer looks into the operator tree in order to determine whether it is necessary to add
+ * metadata columns. In single-pass, we always add metadata columns during the main traversal and
+ * this rule cleans up the columns that are unnecessary. An important thing to note here is that
+ * by "unnecessary" columns we are not referring to the columns that upper operators don't need
+ * for a correct result, but to the columns that have been added by the single-pass resolver and
+ * are not present in the fixed-point plan.
+ */
+object PruneMetadataColumns extends Rule[LogicalPlan] {
+
+ /**
+ * Entry point for [[PruneMetadataColumns]] rule.
+ */
+ override def apply(plan: LogicalPlan): LogicalPlan =
+ pruneMetadataColumns(plan = plan, neededAttributes = new HashSet[ExprId])
+
+ /**
+ * This method recursively prunes away unnecessary metadata columns, going from top to bottom.
+ *
+ * @param plan Operator for which we are pruning columns.
+ * @param neededAttributes A set of [[ExprId]]s that is required by operators above [[plan]].
+ *
+ * We distinguish three separate cases here, based on the type of operator node:
+ * - For [[Aggregate]] nodes we only need to propagate aggregate expressions as needed
+ * attributes to the lower operators. This is because in single-pass, we are not adding metadata
+ * columns to [[Aggregate]] operators.
+ * - For [[Project]] nodes we prune away all metadata columns that are either not required by
+ * the operators above or are duplicated in the project list.
+ * - For all other operators we collect references, add them to [[neededAttributes]] and
+ * recursively call [[pruneMetadataColumns]] for children.
+ */
+ private def pruneMetadataColumns(
+ plan: LogicalPlan,
+ neededAttributes: HashSet[ExprId]): LogicalPlan = plan match {
+ case aggregate: Aggregate =>
+ withNewChildrenPrunedByNeededAttributes(
+ aggregate,
+ aggregate.aggregateExpressions
+ )
+ case project: Project =>
+ pruneMetadataColumnsInProject(project, neededAttributes)
+ case other =>
+ pruneMetadataColumnsGenerically(other, neededAttributes)
+ }
+
+ /**
+ * Prune unnecessary columns for a [[Project]] node and recursively do it for its children. While
+ * pruning we preserve all non-qualified-access-only columns as well as any columns that are
+ * needed in the operators above, but without duplicating them in the project list. This behavior
+ * is consistent with fixed-point's behavior when
+ * [[SQLConf.ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS]] is true. We don't support legacy
+ * behavior in single-pass.
+ *
+ * IMPORTANT NOTE: In this case we prune away only the qualified-access-only columns instead of
+ * all metadata columns. This is because we can have metadata columns from sources other than the
+ * hidden output and, additionally, when a column is resolved from any source, its
+ * qualified-access-only flag is removed (see more in [[AttributeSeq]]). Therefore,
+ * qualified-access-only columns that appear in the project list must have been appended
+ * artificially, and we should potentially
+ * prune them.
+ */
+ private def pruneMetadataColumnsInProject(project: Project, neededAttributes: HashSet[ExprId]) = {
+ val existingExprIds = new HashSet[ExprId]
+ val newProjectList = if (!neededAttributes.isEmpty) {
+ project.projectList.collect {
+ case namedExpression: NamedExpression if !namedExpression.toAttribute.qualifiedAccessOnly =>
+ existingExprIds.add(namedExpression.exprId)
+ namedExpression
+ case namedExpression: NamedExpression
+ if namedExpression.toAttribute.qualifiedAccessOnly && neededAttributes.contains(
+ namedExpression.exprId
+ ) && !existingExprIds.contains(namedExpression.exprId) =>
+ existingExprIds.add(namedExpression.exprId)
+ namedExpression
+ }
+ } else {
+ project.projectList
+ }
+ val projectWithNewChildren =
+ withNewChildrenPrunedByNeededAttributes(project, newProjectList).asInstanceOf[Project]
+ val newProject = CurrentOrigin.withOrigin(projectWithNewChildren.origin) {
+ projectWithNewChildren.copy(projectList = newProjectList)
+ }
+ newProject.copyTagsFrom(project)
+ newProject
+ }
+
+ /**
+ * Prune unnecessary metadata columns in operators that are not [[Project]] or [[Aggregate]].
+ */
+ private def pruneMetadataColumnsGenerically(
+ operator: LogicalPlan,
+ neededAttributes: HashSet[ExprId]) = {
+ operator.references.foreach(attr => neededAttributes.add(attr.exprId))
+ val newChildren = operator.children.map { child =>
+ pruneMetadataColumns(child, new HashSet(neededAttributes))
+ }
+ operator.withNewChildren(newChildren)
+ }
+
+ private def withNewChildrenPrunedByNeededAttributes(
+ plan: LogicalPlan,
+ newNeededAttributes: Seq[NamedExpression]): LogicalPlan = {
+ val neededAttributes = new HashSet[ExprId]
+ newNeededAttributes.foreach {
+ case _: OuterReference =>
+ case other: NamedExpression =>
+ other.foreach {
+ case namedExpression: NamedExpression =>
+ neededAttributes.add(namedExpression.exprId)
+ case _ =>
+ }
+ }
+ plan.withNewChildren(
+ plan.children.map(
+ child => pruneMetadataColumns(child, neededAttributes)
+ )
+ )
+ }
+}
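
To make the pruning described above concrete, here is a small illustrative sketch of applying the rule to a hand-built plan. It is not part of the patch and assumes the `markAsQualifiedAccessOnly()` helper from [[MetadataColumnHelper]] to mark the artificially appended column:

{{{
import org.apache.spark.sql.catalyst.analysis.resolver.PruneMetadataColumns
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types.IntegerType

// `k` stands for a qualified-access-only metadata column appended during resolution;
// no operator above the inner Project references it.
val v = AttributeReference("v", IntegerType)()
val k = AttributeReference("k", IntegerType)().markAsQualifiedAccessOnly()

val plan = Project(Seq(v), Project(Seq(v, k), LocalRelation(v, k)))

// The inner project list is expected to lose `k`: it is qualified-access-only and
// is not needed by any operator above.
val pruned = PruneMetadataColumns(plan)
}}}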
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PullOutNondeterministicExpressionInExpressionTree.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PullOutNondeterministicExpressionInExpressionTree.scala
new file mode 100644
index 0000000000000..3272c6975075c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PullOutNondeterministicExpressionInExpressionTree.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import java.util.LinkedHashMap
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
+
+/**
+ * Replace nondeterministic expressions in an expression tree with the attributes of the
+ * corresponding named expressions from the `nondeterministicToAttributes` map.
+ */
+object PullOutNondeterministicExpressionInExpressionTree {
+ def apply[ExpressionType <: Expression](
+ expression: ExpressionType,
+ nondeterministicToAttributes: LinkedHashMap[Expression, NamedExpression]): ExpressionType = {
+ expression
+ .transform {
+ case childExpression =>
+ nondeterministicToAttributes.get(childExpression) match {
+ case null =>
+ childExpression
+ case namedExpression =>
+ namedExpression.toAttribute
+ }
+ }
+ .asInstanceOf[ExpressionType]
+ }
+}
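
A usage sketch for the helper above (illustrative only; the expression and alias names are made up, and the real call sites live in the expression resolution code paths):

{{{
import java.util.LinkedHashMap

import org.apache.spark.sql.catalyst.analysis.resolver.PullOutNondeterministicExpressionInExpressionTree
import org.apache.spark.sql.catalyst.expressions.{
  Alias,
  Expression,
  GreaterThan,
  Literal,
  MonotonicallyIncreasingID,
  NamedExpression
}

// The nondeterministic expression was previously pulled out into an alias.
val nondeterministic = MonotonicallyIncreasingID()
val pulledOut: NamedExpression = Alias(nondeterministic, "_nondeterministic_0")()

val mapping = new LinkedHashMap[Expression, NamedExpression]()
mapping.put(nondeterministic, pulledOut)

// Any other expression tree that references the same nondeterministic expression is
// rewritten to reference the alias' attribute instead.
val rewritten = PullOutNondeterministicExpressionInExpressionTree(
  GreaterThan(nondeterministic, Literal(0L)),
  mapping
)
}}}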
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala
index 3bf15c51977f4..6a3d8e161f2c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver
import java.util.HashMap
-import org.apache.spark.sql.catalyst.analysis.{withPosition, RelationResolution, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{RelationResolution, UnresolvedRelation}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.LookupCatalog
import org.apache.spark.util.ArrayImplicits._
@@ -85,9 +85,11 @@ trait RelationMetadataProvider extends LookupCatalog {
isStreaming = unresolvedRelation.isStreaming
)
case _ =>
- withPosition(unresolvedRelation) {
- unresolvedRelation.tableNotFound(unresolvedRelation.multipartIdentifier)
- }
+ RelationId(
+ multipartIdentifier = unresolvedRelation.multipartIdentifier,
+ options = unresolvedRelation.options,
+ isStreaming = unresolvedRelation.isStreaming
+ )
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala
index 93034e931bb0e..2e536c33ef416 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala
@@ -26,26 +26,10 @@ import org.apache.spark.sql.catalyst.analysis.{
SchemaBinding
}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.plans.logical.{
- Aggregate,
- CTERelationDef,
- CTERelationRef,
- Distinct,
- Filter,
- GlobalLimit,
- LocalLimit,
- LocalRelation,
- LogicalPlan,
- OneRowRelation,
- Project,
- SubqueryAlias,
- Union,
- View,
- WithCTE
-}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{BooleanType, DataType, MetadataBuilder, StructType}
+import org.apache.spark.sql.types.{BooleanType, MetadataBuilder, StructType}
/**
* The [[ResolutionValidator]] performs the validation work after the logical plan tree is
@@ -55,10 +39,11 @@ import org.apache.spark.sql.types.{BooleanType, DataType, MetadataBuilder, Struc
* The validation approach is single-pass, post-order, complementary to the resolution process.
*/
class ResolutionValidator {
+ private val attributeScopeStack = new AttributeScopeStack
+
private val expressionResolutionValidator = new ExpressionResolutionValidator(this)
- private[resolver] var attributeScopeStack = new AttributeScopeStack
- private val cteRelationDefIds = new HashSet[Long]
+ def getAttributeScopeStack: AttributeScopeStack = attributeScopeStack
/**
* Validate the resolved logical `plan` - assert invariants that should never be false no
@@ -70,7 +55,11 @@ class ResolutionValidator {
validate(plan)
}
- private def validate(operator: LogicalPlan): Unit = {
+ /**
+ * Validate a specific `operator`. This is an internal entry point for the recursive validation.
+ * Also, [[ExpressionResolutionValidator]] calls it to validate [[SubqueryExpression]] plans.
+ */
+ def validate(operator: LogicalPlan): Unit = {
operator match {
case withCte: WithCTE =>
validateWith(withCte)
@@ -92,6 +81,10 @@ class ResolutionValidator {
validateGlobalLimit(globalLimit)
case localLimit: LocalLimit =>
validateLocalLimit(localLimit)
+ case offset: Offset =>
+ validateOffset(offset)
+ case tail: Tail =>
+ validateTail(tail)
case distinct: Distinct =>
validateDistinct(distinct)
case inlineTable: ResolvedInlineTable =>
@@ -100,57 +93,69 @@ class ResolutionValidator {
validateRelation(localRelation)
case oneRowRelation: OneRowRelation =>
validateRelation(oneRowRelation)
- case union: Union =>
- validateUnion(union)
+ case range: Range =>
+ validateRelation(range)
+ case setOperationLike @ (_: Union | _: SetOperation) =>
+ validateSetOperationLike(setOperationLike)
+ case sort: Sort =>
+ validateSort(sort)
+ case join: Join =>
+ validateJoin(join)
+ case repartition: Repartition =>
+ validateRepartition(repartition)
// [[LogicalRelation]], [[HiveTableRelation]] and other specific relations can't be imported
// because of a potential circular dependency, so we match a generic Catalyst
// [[MultiInstanceRelation]] instead.
case multiInstanceRelation: MultiInstanceRelation =>
validateRelation(multiInstanceRelation)
}
- ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds(
- operator.children.map(_.output)
- )
+
+ operator match {
+ case withCte: WithCTE =>
+ case _ =>
+ ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds(
+ operator.children.map(_.output)
+ )
+ }
}
private def validateWith(withCte: WithCTE): Unit = {
+ val knownCteDefIds = new HashSet[Long](withCte.cteDefs.length)
+
for (cteDef <- withCte.cteDefs) {
+ assert(
+ !knownCteDefIds.contains(cteDef.id),
+ s"Duplicate CTE definition id: ${cteDef.id}"
+ )
+
validate(cteDef)
+
+ knownCteDefIds.add(cteDef.id)
}
+
validate(withCte.plan)
}
private def validateCteRelationDef(cteRelationDef: CTERelationDef): Unit = {
validate(cteRelationDef.child)
-
- assert(
- !cteRelationDefIds.contains(cteRelationDef.id),
- s"Duplicate CTE relation def ID: $cteRelationDef"
- )
-
- cteRelationDefIds.add(cteRelationDef.id)
}
private def validateCteRelationRef(cteRelationRef: CTERelationRef): Unit = {
- assert(
- cteRelationDefIds.contains(cteRelationRef.cteId),
- s"CTE relation ref ID is not known: $cteRelationRef"
- )
-
handleOperatorOutput(cteRelationRef)
}
private def validateAggregate(aggregate: Aggregate): Unit = {
- attributeScopeStack.withNewScope {
+ attributeScopeStack.withNewScope() {
validate(aggregate.child)
expressionResolutionValidator.validateProjectList(aggregate.aggregateExpressions)
+ aggregate.groupingExpressions.foreach(expressionResolutionValidator.validate)
}
handleOperatorOutput(aggregate)
}
private def validateProject(project: Project): Unit = {
- attributeScopeStack.withNewScope {
+ attributeScopeStack.withNewScope() {
validate(project.child)
expressionResolutionValidator.validateProjectList(project.projectList)
}
@@ -207,6 +212,16 @@ class ResolutionValidator {
expressionResolutionValidator.validate(localLimit.limitExpr)
}
+ private def validateOffset(offset: Offset): Unit = {
+ validate(offset.child)
+ expressionResolutionValidator.validate(offset.offsetExpr)
+ }
+
+ private def validateTail(tail: Tail): Unit = {
+ validate(tail.child)
+ expressionResolutionValidator.validate(tail.limitExpr)
+ }
+
private def validateDistinct(distinct: Distinct): Unit = {
validate(distinct.child)
}
@@ -225,31 +240,57 @@ class ResolutionValidator {
handleOperatorOutput(relation)
}
- private def validateUnion(union: Union): Unit = {
- union.children.foreach(validate)
+ private def validateSetOperationLike(plan: LogicalPlan): Unit = {
+ plan.children.foreach(validate)
- assert(union.children.length > 1, "Union operator has to have at least 2 children")
- val firstChildOutput = union.children.head.output
- for (child <- union.children.tail) {
+ assert(
+ plan.children.length > 1,
+ s"${plan.nodeName} operator has to have at least 2 children"
+ )
+ val firstChildOutput = plan.children.head.output
+ for (child <- plan.children.tail) {
val childOutput = child.output
assert(
childOutput.length == firstChildOutput.length,
- s"Unexpected output length for Union child $child"
+ s"Unexpected output length for ${plan.nodeName} child $child"
)
- childOutput.zip(firstChildOutput).foreach {
- case (current, first) =>
- assert(
- DataType.equalsStructurally(current.dataType, first.dataType, ignoreNullability = true),
- s"Unexpected type of Union child attribute $current for $child"
- )
+ }
+
+ handleOperatorOutput(plan)
+ }
+
+ private def validateSort(sort: Sort): Unit = {
+ validate(sort.child)
+ for (sortOrder <- sort.order) {
+ expressionResolutionValidator.validate(sortOrder.child)
+ }
+ }
+
+ private def validateRepartition(repartition: Repartition): Unit = {
+ validate(repartition.child)
+ }
+
+ private def validateJoin(join: Join) = {
+ attributeScopeStack.withNewScope() {
+ attributeScopeStack.withNewScope() {
+ validate(join.left)
+ validate(join.right)
+ assert(join.left.outputSet.intersect(join.right.outputSet).isEmpty)
+ }
+
+ attributeScopeStack.overwriteCurrent(join.left.output ++ join.right.output)
+
+ join.condition match {
+ case Some(condition) => expressionResolutionValidator.validate(condition)
+ case None =>
}
}
- handleOperatorOutput(union)
+ handleOperatorOutput(join)
}
private def handleOperatorOutput(operator: LogicalPlan): Unit = {
- attributeScopeStack.overwriteTop(operator.output)
+ attributeScopeStack.overwriteCurrent(operator.output)
operator.output.foreach(attribute => {
assert(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedAggregateExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedAggregateExpressions.scala
new file mode 100644
index 0000000000000..9a3eba5da646c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedAggregateExpressions.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import java.util.HashSet
+
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
+
+/**
+ * [[ResolvedAggregateExpressions]] is used by the
+ * [[ExpressionResolver.resolveAggregateExpressions]] to return resolution results.
+ * - expressions: The resolved expressions. They are resolved using the
+ * `resolveExpressionTreeInOperator`.
+ * - resolvedExpressionsWithoutAggregates: List of resolved aggregate expressions that don't have
+ * [[AggregateExpression]]s in their subtrees.
+ * - hasAttributeOutsideOfAggregateExpressions: True if `expressions` list contains any attributes
+ * that are not under an [[AggregateExpression]].
+ * - hasStar: True if there is a star (`*`) in the aggregate expressions list.
+ * - expressionIndexesWithAggregateFunctions: Indices of expressions in the aggregate expressions list
+ * that have aggregate functions in their subtrees.
+ */
+case class ResolvedAggregateExpressions(
+ expressions: Seq[NamedExpression],
+ resolvedExpressionsWithoutAggregates: Seq[NamedExpression],
+ hasAttributeOutsideOfAggregateExpressions: Boolean,
+ hasStar: Boolean,
+ expressionIndexesWithAggregateFunctions: HashSet[Int]
+)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala
index 7f3ca796a4949..3cb1dc0d454ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala
@@ -25,12 +25,9 @@ import org.apache.spark.sql.catalyst.expressions.NamedExpression
* `resolveExpressionTreeInOperator`.
* - hasAggregateExpressions: True if the resolved project list contains any aggregate
* expressions.
- * - hasAttributes: True if the resolved project list contains any attributes that are not under
- * an aggregate expression.
* - hasLateralColumnAlias: True if the resolved project list contains any lateral column aliases.
*/
case class ResolvedProjectList(
expressions: Seq[NamedExpression],
hasAggregateExpressions: Boolean,
- hasAttributes: Boolean,
hasLateralColumnAlias: Boolean)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedSubqueryExpressionPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedSubqueryExpressionPlan.scala
new file mode 100644
index 0000000000000..30e560b7afdd3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedSubqueryExpressionPlan.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+/**
+ * The result of [[SubqueryExpression.plan]] resolution. This is used internally in
+ * [[SubqueryExpressionResolver]].
+ *
+ * @param plan The resolved plan of the subquery.
+ * @param output Plan output. We don't use [[LogicalPlan.output]] in the single-pass Analyzer,
+ * because that method is often recursive.
+ * @param outerExpressions The outer expressions that are referenced in the plan. [[OuterReference]]
+ * wrapper is stripped away. These can be either actual leaf [[AttributeReference]]s or
+ * [[AggregateExpression]]s with outer references inside.
+ */
+case class ResolvedSubqueryExpressionPlan(
+ plan: LogicalPlan,
+ output: Seq[Attribute],
+ outerExpressions: Seq[Expression])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala
index 511f36b4a7190..44a0ad52be924 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis.resolver
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.EvaluateUnresolvedInlineTable
@@ -33,36 +31,12 @@ import org.apache.spark.sql.catalyst.analysis.{
UnresolvedRelation,
UnresolvedSubqueryColumnAliases
}
-import org.apache.spark.sql.catalyst.expressions.{
- Alias,
- Attribute,
- AttributeSet,
- Expression,
- NamedExpression
-}
-import org.apache.spark.sql.catalyst.plans.logical.{
- AnalysisHelper,
- CTERelationDef,
- CTERelationRef,
- Distinct,
- Filter,
- GlobalLimit,
- LeafNode,
- LocalLimit,
- LocalRelation,
- LogicalPlan,
- OneRowRelation,
- Project,
- SubqueryAlias,
- Union,
- UnresolvedWith,
- View,
- WithCTE
-}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.BooleanType
/**
* The Resolver implements a single-pass bottom-up analysis algorithm in the Catalyst.
@@ -90,18 +64,24 @@ class Resolver(
override val extensions: Seq[ResolverExtension] = Seq.empty,
metadataResolverExtensions: Seq[ResolverExtension] = Seq.empty)
extends LogicalPlanResolver
- with ResolvesOperatorChildren
with DelegatesResolutionToExtensions {
private val scopes = new NameScopeStack
private val cteRegistry = new CteRegistry
+ private val subqueryRegistry = new SubqueryRegistry
private val planLogger = new PlanLogger
+ private val identifierAndCteSubstitutor = new IdentifierAndCteSubstitutor
private val relationResolution = Resolver.createRelationResolution(catalogManager)
private val functionResolution = new FunctionResolution(catalogManager, relationResolution)
private val expressionResolver = new ExpressionResolver(this, functionResolution, planLogger)
+ private val aggregateResolver = new AggregateResolver(this, expressionResolver)
private val expressionIdAssigner = expressionResolver.getExpressionIdAssigner
private val projectResolver = new ProjectResolver(this, expressionResolver)
private val viewResolver = new ViewResolver(resolver = this, catalogManager = catalogManager)
- private val unionResolver = new UnionResolver(this, expressionResolver)
+ private val setOperationLikeResolver =
+ new SetOperationLikeResolver(this, expressionResolver)
+ private val filterResolver = new FilterResolver(this, expressionResolver)
+ private val sortResolver = new SortResolver(this, expressionResolver)
+ private val joinResolver = new JoinResolver(this, expressionResolver)
/**
* [[relationMetadataProvider]] is used to resolve metadata for relations. It's initialized with
@@ -121,36 +101,61 @@ class Resolver(
metadataResolverExtensions
)
+ /**
+ * Get [[NameScopeStack]] bound to the used [[Resolver]].
+ */
+ def getNameScopes: NameScopeStack = scopes
+
/**
* Get the [[CteRegistry]] which is a single instance per query resolution.
*/
- def getCteRegistry: CteRegistry = {
- cteRegistry
- }
+ def getCteRegistry: CteRegistry = cteRegistry
/**
- * This method is an analysis entry point. It resolves the metadata and invokes [[resolve]],
- * which does most of the analysis work.
+ * Get the [[SubqueryRegistry]] which is a single instance per query resolution.
+ */
+ def getSubqueryRegistry: SubqueryRegistry = subqueryRegistry
+
+ /**
+ * This method is a top-level analysis entry point:
+ * 1. Substitute IDENTIFIERs and CTEs in the `unresolvedPlan` using
+ * [[IdentifierAndCteSubstitutor]];
+ * 2. Resolve the metadata for the plan using [[MetadataResolver]]. When
+ * [[ANALYZER_SINGLE_PASS_RESOLVER_RELATION_BRIDGING_ENABLED]] is enabled, we need to
+ * re-instantiate the [[RelationMetadataProvider]] as [[View]] resolution context might have
+ * changed in the meantime;
+ * 3. Resolve the plan using [[resolve]].
+ *
+ * This method is called for the top-level query and each unresolved [[View]].
*/
def lookupMetadataAndResolve(
unresolvedPlan: LogicalPlan,
analyzerBridgeState: Option[AnalyzerBridgeState] = None): LogicalPlan = {
- planLogger.logPlanResolutionEvent(unresolvedPlan, "Lookup metadata and resolve")
+ planLogger.logPlanResolutionEvent(unresolvedPlan, "IDENTIFIER and CTE substitution")
+
+ val planAfterSubstitution = identifierAndCteSubstitutor.substitutePlan(unresolvedPlan)
+
+ planLogger.logPlanResolutionEvent(planAfterSubstitution, "Metadata lookup")
relationMetadataProvider = analyzerBridgeState match {
case Some(analyzerBridgeState) =>
new BridgedRelationMetadataProvider(
catalogManager,
relationResolution,
- analyzerBridgeState
+ analyzerBridgeState,
+ viewResolver
)
case None =>
relationMetadataProvider
}
- relationMetadataProvider.resolve(unresolvedPlan)
+ relationMetadataProvider.resolve(planAfterSubstitution)
+
+ planLogger.logPlanResolutionEvent(planAfterSubstitution, "Main resolution")
+
+ planAfterSubstitution.setTagValue(Resolver.TOP_LEVEL_OPERATOR, ())
- resolve(unresolvedPlan)
+ resolve(planAfterSubstitution)
}
/**
@@ -165,18 +170,24 @@ class Resolver(
* [[resolve]] will be called recursively during the unresolved plan traversal eventually
* producing a fully resolved plan or a descriptive error message.
*/
- override def resolve(unresolvedPlan: LogicalPlan): LogicalPlan =
+ override def resolve(unresolvedPlan: LogicalPlan): LogicalPlan = {
withOrigin(unresolvedPlan.origin) {
planLogger.logPlanResolutionEvent(unresolvedPlan, "Unresolved plan")
val resolvedPlan =
unresolvedPlan match {
+ case unresolvedJoin: Join =>
+ joinResolver.resolve(unresolvedJoin)
case unresolvedWith: UnresolvedWith =>
resolveWith(unresolvedWith)
+ case withCte: WithCTE =>
+ handleResolvedWithCte(withCte)
case unresolvedProject: Project =>
projectResolver.resolve(unresolvedProject)
+ case unresolvedAggregate: Aggregate =>
+ aggregateResolver.resolve(unresolvedAggregate)
case unresolvedFilter: Filter =>
- resolveFilter(unresolvedFilter)
+ filterResolver.resolve(unresolvedFilter)
case unresolvedSubqueryColumnAliases: UnresolvedSubqueryColumnAliases =>
resolveSubqueryColumnAliases(unresolvedSubqueryColumnAliases)
case unresolvedSubqueryAlias: SubqueryAlias =>
@@ -187,16 +198,26 @@ class Resolver(
resolveGlobalLimit(unresolvedGlobalLimit)
case unresolvedLocalLimit: LocalLimit =>
resolveLocalLimit(unresolvedLocalLimit)
+ case unresolvedOffset: Offset =>
+ resolveOffset(unresolvedOffset)
+ case unresolvedTail: Tail =>
+ resolveTail(unresolvedTail)
case unresolvedDistinct: Distinct =>
resolveDistinct(unresolvedDistinct)
case unresolvedRelation: UnresolvedRelation =>
resolveRelation(unresolvedRelation)
- case unresolvedCteRelationDef: CTERelationDef =>
- resolveCteRelationDef(unresolvedCteRelationDef)
+ case unresolvedCteRelationRef: UnresolvedCteRelationRef =>
+ resolveCteRelationRef(unresolvedCteRelationRef)
+ case cteRelationDef: CTERelationDef =>
+ handleResolvedCteRelationDef(cteRelationDef)
+ case cteRelationRef: CTERelationRef =>
+ handleLeafOperator(cteRelationRef)
case unresolvedInlineTable: UnresolvedInlineTable =>
resolveInlineTable(unresolvedInlineTable)
- case unresolvedUnion: Union =>
- unionResolver.resolve(unresolvedUnion)
+ case unresolvedSetOperationLike @ (_: Union | _: SetOperation) =>
+ setOperationLikeResolver.resolve(unresolvedSetOperationLike)
+ case unresolvedSort: Sort =>
+ sortResolver.resolve(unresolvedSort)
// See the reason why we have to match both [[LocalRelation]] and [[ResolvedInlineTable]]
// in the [[resolveInlineTable]] scaladoc
case resolvedInlineTable: ResolvedInlineTable =>
@@ -205,34 +226,25 @@ class Resolver(
handleLeafOperator(localRelation)
case unresolvedOneRowRelation: OneRowRelation =>
handleLeafOperator(unresolvedOneRowRelation)
+ case unresolvedRange: Range =>
+ handleLeafOperator(unresolvedRange)
case _ =>
tryDelegateResolutionToExtension(unresolvedPlan).getOrElse {
handleUnmatchedOperator(unresolvedPlan)
}
}
- if (resolvedPlan.children.nonEmpty) {
- val missingInput = resolvedPlan.missingInput
- if (missingInput.nonEmpty) {
- withPosition(unresolvedPlan) {
- throwMissingAttributesError(resolvedPlan, missingInput)
- }
- }
- }
-
- if (!resolvedPlan.resolved) {
- throwSinglePassFailedToResolveOperator(resolvedPlan)
+ withPosition(unresolvedPlan) {
+ validateResolvedOperatorGenerically(resolvedPlan)
}
planLogger.logPlanResolution(unresolvedPlan, resolvedPlan)
- preservePlanIdTag(unresolvedPlan, resolvedPlan)
- }
+ resolvedPlan.copyTagsFrom(unresolvedPlan)
- /**
- * Get [[NameScopeStack]] bound to the used [[Resolver]].
- */
- def getNameScopes: NameScopeStack = scopes
+ resolvedPlan
+ }
+ }
/**
* [[UnresolvedWith]] contains a list of unresolved CTE definitions, which are represented by
@@ -247,19 +259,13 @@ class Resolver(
* See [[CteScope]] scaladoc for all the details on how CTEs are resolved.
*/
private def resolveWith(unresolvedWith: UnresolvedWith): LogicalPlan = {
- val childOutputs = new ArrayBuffer[Seq[Attribute]]
-
for (cteRelation <- unresolvedWith.cteRelations) {
val (cteName, ctePlan) = cteRelation
- val resolvedCtePlan = scopes.withNewScope {
+ val resolvedCtePlan = scopes.withNewScope() {
expressionIdAssigner.withNewMapping() {
cteRegistry.withNewScope() {
- val resolvedCtePlan = resolve(ctePlan)
-
- childOutputs.append(scopes.top.output)
-
- resolvedCtePlan
+ resolve(ctePlan)
}
}
}
@@ -271,35 +277,31 @@ class Resolver(
resolve(unresolvedWith.child)
}
- childOutputs.append(scopes.top.output)
-
- ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds(childOutputs.toSeq)
-
- if (cteRegistry.currentScope.isRoot) {
- WithCTE(resolvedChild, cteRegistry.currentScope.getKnownCtes)
- } else {
- resolvedChild
- }
+ cteRegistry.currentScope.tryPutWithCTE(
+ unresolvedOperator = unresolvedWith,
+ resolvedOperator = resolvedChild
+ )
}
/**
- * [[Filter]] has a single child and a single condition and we resolve them in this respective
- * order.
+ * We may encounter a resolved [[WithCTE]] while traversing partially resolved trees in DataFrame
+ * programs. In that case we simply recurse into the CTE definitions and the main plan under
+ * new scopes and mappings.
*/
- private def resolveFilter(unresolvedFilter: Filter): LogicalPlan = {
- val resolvedChild = resolve(unresolvedFilter.child)
- val resolvedCondition =
- expressionResolver
- .resolveExpressionTreeInOperator(unresolvedFilter.condition, unresolvedFilter)
-
- val resolvedFilter = Filter(resolvedCondition, resolvedChild)
- if (resolvedFilter.condition.dataType != BooleanType) {
- withPosition(unresolvedFilter) {
- throwDatatypeMismatchFilterNotBoolean(resolvedFilter)
+ private def handleResolvedWithCte(withCte: WithCTE): LogicalPlan = {
+ val resolvedCteDefs = withCte.cteDefs.map { cteDef =>
+ scopes.withNewScope() {
+ expressionIdAssigner.withNewMapping() {
+ cteRegistry.withNewScope() {
+ resolve(cteDef).asInstanceOf[CTERelationDef]
+ }
+ }
}
}
- resolvedFilter
+ val resolvedPlan = resolve(withCte.plan)
+
+ WithCTE(plan = resolvedPlan, cteDefs = resolvedCteDefs)
}
/**
@@ -315,21 +317,20 @@ class Resolver(
unresolvedSubqueryColumnAliases: UnresolvedSubqueryColumnAliases): LogicalPlan = {
val resolvedChild = resolve(unresolvedSubqueryColumnAliases.child)
- if (unresolvedSubqueryColumnAliases.outputColumnNames.size != scopes.top.output.size) {
- withPosition(unresolvedSubqueryColumnAliases) {
- throw QueryCompilationErrors.aliasNumberNotMatchColumnNumberError(
- unresolvedSubqueryColumnAliases.outputColumnNames.size,
- scopes.top.output.size,
- unresolvedSubqueryColumnAliases
- )
- }
+ if (unresolvedSubqueryColumnAliases.outputColumnNames.size != scopes.current.output.size) {
+ throw QueryCompilationErrors.aliasNumberNotMatchColumnNumberError(
+ unresolvedSubqueryColumnAliases.outputColumnNames.size,
+ scopes.current.output.size,
+ unresolvedSubqueryColumnAliases
+ )
}
- val projectList = scopes.top.output.zip(unresolvedSubqueryColumnAliases.outputColumnNames).map {
- case (attr, columnName) => expressionIdAssigner.mapExpression(Alias(attr, columnName)())
- }
+ val projectList =
+ scopes.current.output.zip(unresolvedSubqueryColumnAliases.outputColumnNames).map {
+ case (attr, columnName) => expressionIdAssigner.mapExpression(Alias(attr, columnName)())
+ }
- overwriteTopScope(unresolvedSubqueryColumnAliases, projectList.map(_.toAttribute))
+ scopes.overwriteCurrent(output = Some(projectList.map(_.toAttribute)))
Project(projectList = projectList, child = resolvedChild)
}
@@ -338,16 +339,30 @@ class Resolver(
* [[SubqueryAlias]] has a single child and an identifier. We need to resolve the child and update
* the scope with the output, since upper expressions can reference [[SubqueryAlias]]es output by
* its identifier.
+ *
+ * Hidden output is reset when [[SubqueryAlias]] is reached during tree traversal. This has to be
+ * done because an upper SQL projection (or DataFrame operation) cannot look down into the
+ * previous [[SubqueryAlias]]'s hidden output.
+ * Examples (both will throw `UNRESOLVED_COLUMN` exception):
+ *
+ * 1. SQL
+ * {{{
+ * -- Hidden output will be reset at [[SubqueryAlias]] and therefore both `output` and
+ * -- `hiddenOutput` will be [`col2`]. Because of that, `UNRESOLVED_COLUMN` is thrown
+ * -- when resolving `col1`
+ * SELECT col1 FROM (SELECT col2 FROM VALUES (1, 2));
+ * }}}
+ *
+ * 2. DataFrame
+ * {{{ spark.sql("SELECT * FROM VALUES (1, 2)").select("col1").as("q1").select("col2"); }}}
*/
private def resolveSubqueryAlias(unresolvedSubqueryAlias: SubqueryAlias): LogicalPlan = {
val resolvedSubqueryAlias =
unresolvedSubqueryAlias.copy(child = resolve(unresolvedSubqueryAlias.child))
val qualifier = resolvedSubqueryAlias.identifier.qualifier :+ resolvedSubqueryAlias.alias
- overwriteTopScope(
- unresolvedSubqueryAlias,
- scopes.top.output.map(attribute => attribute.withQualifier(qualifier))
- )
+ val output = scopes.current.output.map(attribute => attribute.withQualifier(qualifier))
+ scopes.overwriteCurrent(output = Some(output), hiddenOutput = Some(output))
resolvedSubqueryAlias
}
@@ -359,12 +374,10 @@ class Resolver(
private def resolveGlobalLimit(unresolvedGlobalLimit: GlobalLimit): LogicalPlan = {
val resolvedChild = resolve(unresolvedGlobalLimit.child)
- val resolvedLimitExpr = withPosition(unresolvedGlobalLimit) {
- expressionResolver.resolveLimitExpression(
- unresolvedGlobalLimit.limitExpr,
- unresolvedGlobalLimit
- )
- }
+ val resolvedLimitExpr = expressionResolver.resolveLimitLikeExpression(
+ unresolvedGlobalLimit.limitExpr,
+ unresolvedGlobalLimit.copy(child = resolvedChild)
+ )
GlobalLimit(resolvedLimitExpr, resolvedChild)
}
@@ -376,21 +389,53 @@ class Resolver(
private def resolveLocalLimit(unresolvedLocalLimit: LocalLimit): LogicalPlan = {
val resolvedChild = resolve(unresolvedLocalLimit.child)
- val resolvedLimitExpr = withPosition(unresolvedLocalLimit) {
- expressionResolver.resolveLimitExpression(
- unresolvedLocalLimit.limitExpr,
- unresolvedLocalLimit
- )
- }
+ val resolvedLimitExpr = expressionResolver.resolveLimitLikeExpression(
+ unresolvedLocalLimit.limitExpr,
+ unresolvedLocalLimit.copy(child = resolvedChild)
+ )
LocalLimit(resolvedLimitExpr, resolvedChild)
}
+ /**
+ * Resolve [[Offset]]. We have to resolve its child, then resolve and validate its offset
+ * expression.
+ */
+ private def resolveOffset(unresolvedOffset: Offset): LogicalPlan = {
+ val resolvedChild = resolve(unresolvedOffset.child)
+
+ val resolvedOffsetExpr = expressionResolver.resolveLimitLikeExpression(
+ unresolvedOffset.offsetExpr,
+ unresolvedOffset.copy(child = resolvedChild)
+ )
+
+ Offset(resolvedOffsetExpr, resolvedChild)
+ }
+
+ /**
+ * Resolve [[Tail]]. We have to resolve its child, then resolve and validate its limit
+ * expression.
+ */
+ private def resolveTail(unresolvedTail: Tail): LogicalPlan = {
+ val resolvedChild = resolve(unresolvedTail.child)
+
+ val resolvedTailExpr = expressionResolver.resolveLimitLikeExpression(
+ unresolvedTail.limitExpr,
+ unresolvedTail.copy(child = resolvedChild)
+ )
+
+ Tail(resolvedTailExpr, resolvedChild)
+ }
+
/**
* [[Distinct]] operator doesn't require any special resolution.
+ *
+ * Hidden output is reset when [[Distinct]] is reached during tree traversal.
*/
private def resolveDistinct(unresolvedDistinct: Distinct): LogicalPlan = {
- withResolvedChildren(unresolvedDistinct, resolve)
+ val resolvedDistinct = unresolvedDistinct.copy(child = resolve(unresolvedDistinct.child))
+ scopes.overwriteCurrent(hiddenOutput = Some(scopes.current.output))
+ resolvedDistinct
}
/**
@@ -400,50 +445,53 @@ class Resolver(
* - Resolve it further, usually using extensions, like [[DataSourceResolver]]
*/
private def resolveRelation(unresolvedRelation: UnresolvedRelation): LogicalPlan = {
- withPosition(unresolvedRelation) {
- viewResolver.withSourceUnresolvedRelation(unresolvedRelation) {
- val maybeResolvedRelation = cteRegistry.resolveCteName(unresolvedRelation.name).orElse {
- relationMetadataProvider.getRelationWithResolvedMetadata(unresolvedRelation)
- }
-
- val resolvedRelation = maybeResolvedRelation match {
- case Some(cteRelationDef: CTERelationDef) =>
- planLogger.logPlanResolutionEvent(cteRelationDef, "CTE definition resolved")
-
- SubqueryAlias(identifier = unresolvedRelation.name, child = cteRelationDef)
- case Some(relationsWithResolvedMetadata) =>
- planLogger.logPlanResolutionEvent(
- relationsWithResolvedMetadata,
- "Relation metadata retrieved"
- )
-
- relationsWithResolvedMetadata
- case None =>
- unresolvedRelation.tableNotFound(unresolvedRelation.multipartIdentifier)
- }
-
- resolve(resolvedRelation)
+ viewResolver.withSourceUnresolvedRelation(unresolvedRelation) {
+ val maybeResolvedRelation =
+ relationMetadataProvider.getRelationWithResolvedMetadata(unresolvedRelation)
+
+ val resolvedRelation = maybeResolvedRelation match {
+ case Some(relationsWithResolvedMetadata) =>
+ planLogger.logPlanResolutionEvent(
+ relationsWithResolvedMetadata,
+ "Relation metadata retrieved"
+ )
+
+ relationsWithResolvedMetadata
+ case None =>
+ unresolvedRelation.tableNotFound(unresolvedRelation.multipartIdentifier)
}
+
+ resolve(resolvedRelation)
}
}
/**
- * Resolve [[CTERelationDef]] by replacing it with [[CTERelationRef]] with the same ID so that
- * the Optimizer can make a decision whether to inline the definition or not.
- *
- * [[CTERelationDef.statsOpt]] is filled by the Optimizer.
+ * Resolve the [[UnresolvedCteRelationRef]] which was previously introduced by the
+ * [[IdentifierAndCteSubstitutor]].
*/
- private def resolveCteRelationDef(unresolvedCteRelationDef: CTERelationDef): LogicalPlan = {
- val cteRelationRef = CTERelationRef(
- cteId = unresolvedCteRelationDef.id,
- _resolved = true,
- isStreaming = unresolvedCteRelationDef.isStreaming,
- output = unresolvedCteRelationDef.output,
- recursive = false,
- maxRows = unresolvedCteRelationDef.maxRows
- )
+ private def resolveCteRelationRef(
+ unresolvedCteRelationRef: UnresolvedCteRelationRef): LogicalPlan = {
+ val cteRelationRef = cteRegistry.resolveCteName(unresolvedCteRelationRef.name) match {
+ case Some(cteRelationDef) =>
+ planLogger.logPlanResolutionEvent(cteRelationDef, "CTE definition resolved")
+
+ createCteRelationRef(
+ name = unresolvedCteRelationRef.name,
+ cteRelationDef = cteRelationDef
+ )
+ case None =>
+ unresolvedCteRelationRef.tableNotFound(Seq(unresolvedCteRelationRef.name))
+ }
+
+ resolve(cteRelationRef)
+ }
- handleLeafOperator(cteRelationRef)
+ /**
+ * We may encounter a resolved [[CTERelationDef]] while traversing partially resolved trees in
+ * DataFrame programs. In that case we simply recurse into its child plan.
+ */
+ private def handleResolvedCteRelationDef(cteRelationDef: CTERelationDef): LogicalPlan = {
+ cteRelationDef.copy(child = resolve(cteRelationDef.child))
}
/**
@@ -470,23 +518,7 @@ class Resolver(
val resolvedRelation = EvaluateUnresolvedInlineTable
.evaluateUnresolvedInlineTable(withResolvedExpressions)
- withPosition(unresolvedInlineTable) {
- resolve(resolvedRelation)
- }
- }
-
- /**
- * Preserve `PLAN_ID_TAG` which is used for DataFrame column resolution in Spark Connect.
- */
- private def preservePlanIdTag(
- unresolvedOperator: LogicalPlan,
- resolvedOperator: LogicalPlan): LogicalPlan = {
- unresolvedOperator.getTagValue(LogicalPlan.PLAN_ID_TAG) match {
- case Some(planIdTag) =>
- resolvedOperator.setTagValue(LogicalPlan.PLAN_ID_TAG, planIdTag)
- case None =>
- }
- resolvedOperator
+ resolve(resolvedRelation)
}
private def tryDelegateResolutionToExtension(
@@ -507,57 +539,70 @@ class Resolver(
* `leafOperator`'s output attribute IDs. We don't reassign expression IDs in the leftmost
* branch, see [[ExpressionIdAssigner]] class doc for more details.
* [[CTERelationRef]]'s output can always be reassigned.
- * - Overwrite the current [[NameScope]] with remapped output attributes. It's OK to call
+ * - Overwrite the current [[NameScope]] with remapped output attributes (both
+ * [[NameScope.output]] and [[NameScope.hiddenOutput]] are updated). It's OK to call
* `output` on a [[LeafNode]], because it's not recursive (this call fits the single-pass
* framework).
*/
private def handleLeafOperator(leafOperator: LeafNode): LogicalPlan = {
val leafOperatorWithAssignedExpressionIds = leafOperator match {
- case leafOperator
- if expressionIdAssigner.isLeftmostBranch && !leafOperator.isInstanceOf[CTERelationRef] =>
- expressionIdAssigner.createMapping(newOutput = leafOperator.output)
+ case leafOperator if expressionIdAssigner.shouldPreserveLeafOperatorIds(leafOperator) =>
+ expressionIdAssigner.createMappingForLeafOperator(newOperator = leafOperator)
+
leafOperator
- /**
- * [[InMemoryRelation.statsOfPlanToCache]] is mutable and does not get copied during normal
- * [[transformExpressionsUp]]. The easiest way to correctly copy it is via [[newInstance]]
- * call.
- *
- * We match [[MultiInstanceRelation]] to avoid a cyclic import between [[catalyst]] and
- * [[execution]].
- */
- case originalRelation: MultiInstanceRelation =>
- val newRelation = originalRelation.newInstance()
-
- expressionIdAssigner.createMapping(
- newOutput = newRelation.output,
- oldOutput = Some(originalRelation.output)
- )
+ case originalLeafOperator =>
+ val newLeafOperator = originalLeafOperator match {
- newRelation
- case _ =>
- expressionIdAssigner.createMapping()
-
- AnalysisHelper.allowInvokingTransformsInAnalyzer {
- leafOperator.transformExpressionsUp {
- case expression: NamedExpression =>
- val newExpression = expressionIdAssigner.mapExpression(expression)
- if (newExpression.eq(expression)) {
- throw SparkException.internalError(
- s"Leaf operator expression ID was not reassigned. Expression: $expression, " +
- s"leaf operator: $leafOperator"
- )
+ /**
+ * [[InMemoryRelation.statsOfPlanToCache]] is mutable and does not get copied during
+ * [[transformExpressionsUp]]. The easiest way to correctly copy it is via
+ * [[newInstance]] call.
+ *
+ * We match [[MultiInstanceRelation]] to avoid a cyclic import between [[catalyst]] and
+ * [[execution]].
+ */
+ case originalRelation: MultiInstanceRelation =>
+ originalRelation.newInstance().asInstanceOf[LeafNode]
+
+ case _ =>
+ AnalysisHelper
+ .allowInvokingTransformsInAnalyzer {
+ leafOperator.transformExpressions {
+ case attribute: Attribute => attribute.newInstance()
+ }
}
- newExpression
- }
+ .asInstanceOf[LeafNode]
}
+
+ expressionIdAssigner.createMappingForLeafOperator(
+ newOperator = newLeafOperator,
+ oldOperator = Some(originalLeafOperator)
+ )
+
+ newLeafOperator
}
- overwriteTopScope(leafOperator, leafOperatorWithAssignedExpressionIds.output)
+ val output = leafOperatorWithAssignedExpressionIds.output
+ scopes.overwriteCurrent(output = Some(output), hiddenOutput = Some(output))
leafOperatorWithAssignedExpressionIds
}
+ private def createCteRelationRef(name: String, cteRelationDef: CTERelationDef): LogicalPlan = {
+ SubqueryAlias(
+ identifier = name,
+ child = CTERelationRef(
+ cteId = cteRelationDef.id,
+ _resolved = true,
+ isStreaming = cteRelationDef.isStreaming,
+ output = cteRelationDef.output,
+ recursive = false,
+ maxRows = cteRelationDef.maxRows
+ )
+ )
+ }
+
/**
* Check if the unresolved operator is explicitly unsupported and throw
* [[ExplicitlyUnsupportedResolverFeature]] in that case. Otherwise, throw
@@ -578,15 +623,23 @@ class Resolver(
.withPosition(unresolvedOperator.origin)
}
- private def throwDatatypeMismatchFilterNotBoolean(filter: Filter): Nothing =
- throw new AnalysisException(
- errorClass = "DATATYPE_MISMATCH.FILTER_NOT_BOOLEAN",
- messageParameters = Map(
- "sqlExpr" -> makeCommaSeparatedExpressionString(filter.expressions),
- "filter" -> toSQLExpr(filter.condition),
- "type" -> toSQLType(filter.condition.dataType)
- )
- )
+ private def validateResolvedOperatorGenerically(resolvedOperator: LogicalPlan): Unit = {
+ if (!resolvedOperator.resolved) {
+ throwSinglePassFailedToResolveOperator(resolvedOperator)
+ }
+
+ if (resolvedOperator.children.nonEmpty) {
+ val missingInput = resolvedOperator.missingInput
+ if (missingInput.nonEmpty) {
+ throwMissingAttributesError(resolvedOperator, missingInput)
+ }
+ }
+
+ val invalidExpressions = expressionResolver.getLastInvalidExpressionsInTheContextOfOperator
+ if (invalidExpressions.nonEmpty) {
+ throwUnsupportedExprForOperator(invalidExpressions)
+ }
+ }
private def throwMissingAttributesError(
operator: LogicalPlan,
@@ -631,21 +684,27 @@ class Resolver(
summary = operator.origin.context.summary()
)
- private def makeCommaSeparatedExpressionString(expressions: Seq[Expression]): String = {
- expressions.map(toSQLExpr).mkString(", ")
+ private def throwUnsupportedExprForOperator(invalidExpressions: Seq[Expression]): Nothing = {
+ throw new AnalysisException(
+ errorClass = "UNSUPPORTED_EXPR_FOR_OPERATOR",
+ messageParameters = Map(
+ "invalidExprSqls" -> makeCommaSeparatedExpressionString(invalidExpressions)
+ )
+ )
}
- private def overwriteTopScope(
- sourceUnresolvedOperator: LogicalPlan,
- output: Seq[Attribute]): Unit = {
- withPosition(sourceUnresolvedOperator) {
- scopes.overwriteTop(output)
- }
+ private def makeCommaSeparatedExpressionString(expressions: Seq[Expression]): String = {
+ expressions.map(toSQLExpr).mkString(", ")
}
}
object Resolver {
+ /**
+ * Marks the operator as the top-most operator in a query or a view.
+ */
+ val TOP_LEVEL_OPERATOR = TreeNodeTag[Unit]("top_level_operator")
+
/**
* Create a new instance of the [[RelationResolution]].
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
index 833f50a5203d9..334ffcdf2f638 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
@@ -19,7 +19,11 @@ package org.apache.spark.sql.catalyst.analysis.resolver
import java.util.Locale
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, SqlScriptingLocalVariableManager}
+import org.apache.spark.sql.catalyst.{
+ FunctionIdentifier,
+ SQLConfHelper,
+ SqlScriptingLocalVariableManager
+}
import org.apache.spark.sql.catalyst.analysis.{
FunctionRegistry,
GetViewColumnByNameAndOrdinal,
@@ -32,34 +36,10 @@ import org.apache.spark.sql.catalyst.analysis.{
UnresolvedStar,
UnresolvedSubqueryColumnAliases
}
-import org.apache.spark.sql.catalyst.expressions.{
- Alias,
- AttributeReference,
- BinaryArithmetic,
- Cast,
- ConditionalExpression,
- CreateNamedStruct,
- Expression,
- Literal,
- Predicate,
- SubqueryExpression,
- UpCast
-}
-import org.apache.spark.sql.catalyst.plans.logical.{
- Distinct,
- Filter,
- GlobalLimit,
- LocalLimit,
- LocalRelation,
- LogicalPlan,
- OneRowRelation,
- Project,
- SubqueryAlias,
- Union,
- UnresolvedWith,
- View
-}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode
@@ -75,46 +55,93 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
/**
* Check the top level operator of the parsed operator.
*/
- def apply(operator: LogicalPlan): Boolean =
- checkConfValues() && checkVariables() && checkOperator(operator)
+ def apply(operator: LogicalPlan): Boolean = {
+ val unsupportedConf = detectUnsupportedConf()
+ if (unsupportedConf.isDefined) {
+ tryThrowUnsupportedSinglePassAnalyzerFeature(s"configuration: ${unsupportedConf.get}")
+ }
+
+ val areTempVariablesSupported = checkTempVariables()
+ if (!areTempVariablesSupported) {
+ tryThrowUnsupportedSinglePassAnalyzerFeature("temp variables")
+ }
+
+ val areScriptingVariablesSupported = checkScriptingVariables()
+ if (!areScriptingVariablesSupported) {
+ tryThrowUnsupportedSinglePassAnalyzerFeature("scripting variables")
+ }
+
+ !unsupportedConf.isDefined &&
+ areTempVariablesSupported &&
+ areScriptingVariablesSupported &&
+ checkOperator(operator)
+ }
/**
* Check if all the operators are supported. For implemented ones, recursively check
* their children. For unimplemented ones, return false.
*/
- private def checkOperator(operator: LogicalPlan): Boolean = operator match {
- case unresolvedWith: UnresolvedWith =>
- checkUnresolvedWith(unresolvedWith)
- case project: Project =>
- checkProject(project)
- case filter: Filter =>
- checkFilter(filter)
- case unresolvedSubqueryColumnAliases: UnresolvedSubqueryColumnAliases =>
- checkUnresolvedSubqueryColumnAliases(unresolvedSubqueryColumnAliases)
- case subqueryAlias: SubqueryAlias =>
- checkSubqueryAlias(subqueryAlias)
- case globalLimit: GlobalLimit =>
- checkGlobalLimit(globalLimit)
- case localLimit: LocalLimit =>
- checkLocalLimit(localLimit)
- case distinct: Distinct =>
- checkDistinct(distinct)
- case view: View =>
- checkView(view)
- case unresolvedRelation: UnresolvedRelation =>
- checkUnresolvedRelation(unresolvedRelation)
- case unresolvedInlineTable: UnresolvedInlineTable =>
- checkUnresolvedInlineTable(unresolvedInlineTable)
- case resolvedInlineTable: ResolvedInlineTable =>
- checkResolvedInlineTable(resolvedInlineTable)
- case localRelation: LocalRelation =>
- checkLocalRelation(localRelation)
- case oneRowRelation: OneRowRelation =>
- checkOneRowRelation(oneRowRelation)
- case union: Union =>
- checkUnion(union)
- case _ =>
- false
+ private def checkOperator(operator: LogicalPlan): Boolean = {
+ val isSupported = operator match {
+ case unresolvedWith: UnresolvedWith =>
+ checkUnresolvedWith(unresolvedWith)
+ case withCte: WithCTE =>
+ checkWithCte(withCte)
+ case project: Project =>
+ checkProject(project)
+ case aggregate: Aggregate =>
+ checkAggregate(aggregate)
+ case filter: Filter =>
+ checkFilter(filter)
+ case join: Join =>
+ checkJoin(join)
+ case unresolvedSubqueryColumnAliases: UnresolvedSubqueryColumnAliases =>
+ checkUnresolvedSubqueryColumnAliases(unresolvedSubqueryColumnAliases)
+ case subqueryAlias: SubqueryAlias =>
+ checkSubqueryAlias(subqueryAlias)
+ case globalLimit: GlobalLimit =>
+ checkGlobalLimit(globalLimit)
+ case localLimit: LocalLimit =>
+ checkLocalLimit(localLimit)
+ case offset: Offset =>
+ checkOffset(offset)
+ case tail: Tail =>
+ checkTail(tail)
+ case distinct: Distinct =>
+ checkDistinct(distinct)
+ case view: View =>
+ checkView(view)
+ case unresolvedRelation: UnresolvedRelation =>
+ checkUnresolvedRelation(unresolvedRelation)
+ case unresolvedInlineTable: UnresolvedInlineTable =>
+ checkUnresolvedInlineTable(unresolvedInlineTable)
+ case resolvedInlineTable: ResolvedInlineTable =>
+ checkResolvedInlineTable(resolvedInlineTable)
+ case localRelation: LocalRelation =>
+ checkLocalRelation(localRelation)
+ case range: Range =>
+ checkRange(range)
+ case oneRowRelation: OneRowRelation =>
+ checkOneRowRelation(oneRowRelation)
+ case cteRelationDef: CTERelationDef =>
+ checkCteRelationDef(cteRelationDef)
+ case cteRelationRef: CTERelationRef =>
+ checkCteRelationRef(cteRelationRef)
+ case union: Union =>
+ checkUnion(union)
+ case setOperation: SetOperation =>
+ checkSetOperation(setOperation)
+ case sort: Sort =>
+ checkSort(sort)
+ case _ =>
+ false
+ }
+
+ if (!isSupported) {
+ tryThrowUnsupportedSinglePassAnalyzerFeature(operator)
+ }
+
+ isSupported
}
/**
@@ -122,11 +149,9 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
* For LeafNode types, we return true or false. For other ones, check their children.
*/
private def checkExpression(expression: Expression): Boolean = {
- expression match {
+ val isSupported = expression match {
case alias: Alias =>
checkAlias(alias)
- case unresolvedBinaryArithmetic: BinaryArithmetic =>
- checkUnresolvedBinaryArithmetic(unresolvedBinaryArithmetic)
case unresolvedConditionalExpression: ConditionalExpression =>
checkUnresolvedConditionalExpression(unresolvedConditionalExpression)
case unresolvedCast: Cast =>
@@ -139,10 +164,16 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
checkUnresolvedAlias(unresolvedAlias)
case unresolvedAttribute: UnresolvedAttribute =>
checkUnresolvedAttribute(unresolvedAttribute)
- case unresolvedPredicate: Predicate =>
- checkUnresolvedPredicate(unresolvedPredicate)
case literal: Literal =>
checkLiteral(literal)
+ case unresolvedPredicate: Predicate =>
+ checkUnresolvedPredicate(unresolvedPredicate)
+ case scalarSubquery: ScalarSubquery =>
+ checkScalarSubquery(scalarSubquery)
+ case listQuery: ListQuery =>
+ checkListQuery(listQuery)
+ case outerReference: OuterReference =>
+ checkOuterReference(outerReference)
case attributeReference: AttributeReference =>
checkAttributeReference(attributeReference)
case createNamedStruct: CreateNamedStruct =>
@@ -151,9 +182,17 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
checkUnresolvedFunction(unresolvedFunction)
case getViewColumnByNameAndOrdinal: GetViewColumnByNameAndOrdinal =>
checkGetViewColumnBynameAndOrdinal(getViewColumnByNameAndOrdinal)
+ case expression if isGenerallySupportedExpression(expression) =>
+ expression.children.forall(checkExpression)
case _ =>
false
}
+
+ if (!isSupported) {
+ tryThrowUnsupportedSinglePassAnalyzerFeature(expression)
+ }
+
+ isSupported
}
private def checkUnresolvedWith(unresolvedWith: UnresolvedWith) = {
@@ -163,10 +202,29 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
} && checkOperator(unresolvedWith.child)
}
+ private def checkWithCte(withCte: WithCTE) = {
+ withCte.children.forall(checkOperator)
+ }
+
private def checkProject(project: Project) = {
checkOperator(project.child) && project.projectList.forall(checkExpression)
}
+ private def checkAggregate(aggregate: Aggregate) = {
+ checkOperator(aggregate.child) &&
+ aggregate.groupingExpressions.forall(checkExpression) &&
+ aggregate.aggregateExpressions.forall(checkExpression)
+ }
+
+ private def checkJoin(join: Join) = {
+ checkOperator(join.left) && checkOperator(join.right) && {
+ join.condition match {
+ case Some(condition) => checkExpression(condition)
+ case None => true
+ }
+ }
+ }
+
private def checkFilter(unresolvedFilter: Filter) =
checkOperator(unresolvedFilter.child) && checkExpression(unresolvedFilter.condition)
@@ -183,6 +241,12 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
private def checkLocalLimit(localLimit: LocalLimit) =
checkOperator(localLimit.child) && checkExpression(localLimit.limitExpr)
+ private def checkOffset(offset: Offset) =
+ checkOperator(offset.child) && checkExpression(offset.offsetExpr)
+
+ private def checkTail(tail: Tail) =
+ checkOperator(tail.child) && checkExpression(tail.limitExpr)
+
private def checkDistinct(distinct: Distinct) =
checkOperator(distinct.child)
@@ -201,16 +265,29 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
private def checkLocalRelation(localRelation: LocalRelation) =
localRelation.output.forall(checkExpression)
+ private def checkRange(range: Range) = true
+
private def checkUnion(union: Union) =
!union.byName && !union.allowMissingCol && union.children.forall(checkOperator)
+ private def checkSetOperation(setOperation: SetOperation) =
+ setOperation.children.forall(checkOperator)
+
+ private def checkSort(sort: Sort) = {
+ checkOperator(sort.child) && sort.order.forall(
+ sortOrder => checkExpression(sortOrder.child)
+ )
+ }
+
private def checkOneRowRelation(oneRowRelation: OneRowRelation) = true
- private def checkAlias(alias: Alias) = checkExpression(alias.child)
+ private def checkCteRelationDef(cteRelationDef: CTERelationDef) = {
+ checkOperator(cteRelationDef.child)
+ }
- private def checkUnresolvedBinaryArithmetic(unresolvedBinaryArithmetic: BinaryArithmetic) =
- checkExpression(unresolvedBinaryArithmetic.left) &&
- checkExpression(unresolvedBinaryArithmetic.right)
+ private def checkCteRelationRef(cteRelationRef: CTERelationRef) = true
+
+ private def checkAlias(alias: Alias) = checkExpression(alias.child)
private def checkUnresolvedConditionalExpression(
unresolvedConditionalExpression: ConditionalExpression) =
@@ -229,12 +306,13 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
!ResolverGuard.UNSUPPORTED_ATTRIBUTE_NAMES.contains(unresolvedAttribute.nameParts.head) &&
!unresolvedAttribute.getTagValue(LogicalPlan.PLAN_ID_TAG).isDefined
- private def checkUnresolvedPredicate(unresolvedPredicate: Predicate) = {
- unresolvedPredicate match {
- case _: SubqueryExpression => false
- case other =>
- other.children.forall(checkExpression)
- }
+ private def checkUnresolvedPredicate(unresolvedPredicate: Predicate) = unresolvedPredicate match {
+ case inSubquery: InSubquery =>
+ checkInSubquery(inSubquery)
+ case exists: Exists =>
+ checkExists(exists)
+ case _ =>
+ unresolvedPredicate.children.forall(checkExpression)
}
private def checkAttributeReference(attributeReference: AttributeReference) = true
@@ -253,22 +331,128 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
private def checkLiteral(literal: Literal) = true
+ private def checkScalarSubquery(scalarSubquery: ScalarSubquery) =
+ checkOperator(scalarSubquery.plan)
+
+ private def checkInSubquery(inSubquery: InSubquery) =
+ inSubquery.values.forall(checkExpression) && checkExpression(inSubquery.query)
+
+ private def checkListQuery(listQuery: ListQuery) = checkOperator(listQuery.plan)
+
+ private def checkExists(exists: Exists) = checkOperator(exists.plan)
+
+ private def checkOuterReference(outerReference: OuterReference) =
+ checkExpression(outerReference.e)
+
private def checkGetViewColumnBynameAndOrdinal(
getViewColumnByNameAndOrdinal: GetViewColumnByNameAndOrdinal) = true
- private def checkConfValues() =
- // Case sensitive analysis is not supported.
- !conf.caseSensitiveAnalysis &&
- // Case-sensitive inference is not supported for Hive table schema.
- conf.caseSensitiveInferenceMode == HiveCaseSensitiveInferenceMode.NEVER_INFER &&
- // Legacy CTE resolution modes are not supported.
- !conf.getConf(SQLConf.LEGACY_INLINE_CTE_IN_COMMANDS) &&
- LegacyBehaviorPolicy.withName(conf.getConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY)) ==
- LegacyBehaviorPolicy.CORRECTED
-
- private def checkVariables() =
- catalogManager.tempVariableManager.isEmpty &&
- SqlScriptingLocalVariableManager.get().forall(_.isEmpty)
+ /**
+ * Most of the expressions come from resolving the [[UnresolvedFunction]], but we also keep an
+ * allowlist of some popular expressions here, for two reasons:
+ * 1. Some of them are allocated in the Parser;
+ * 2. To allow the resolution of resolved DataFrame subtrees.
+ */
+ private def isGenerallySupportedExpression(expression: Expression): Boolean = {
+ expression match {
+ // Math
+ case _: UnaryMinus | _: BinaryArithmetic | _: LeafMathExpression | _: UnaryMathExpression |
+ _: UnaryLogExpression | _: BinaryMathExpression | _: BitShiftOperation | _: RoundCeil |
+ _: Conv | _: RoundBase | _: Factorial | _: Bin | _: Hex | _: Unhex | _: WidthBucket =>
+ true
+ // Strings
+ case _: Collate | _: Collation | _: ResolvedCollation | _: UnresolvedCollation | _: Concat |
+ _: Mask | _: ConcatWs | _: Elt | _: Upper | _: Lower | _: BinaryPredicate |
+ _: StringPredicate | _: IsValidUTF8 | _: MakeValidUTF8 | _: ValidateUTF8 |
+ _: TryValidateUTF8 | _: StringReplace | _: Overlay | _: StringTranslate | _: FindInSet |
+ _: String2TrimExpression | _: StringTrimBoth | _: StringInstr | _: SubstringIndex |
+ _: StringLocate | _: StringLPad | _: BinaryPad | _: StringRPad | _: FormatString |
+ _: InitCap | _: StringRepeat | _: StringSpace | _: Substring | _: Right | _: Left |
+ _: Length | _: BitLength | _: OctetLength | _: Levenshtein | _: SoundEx | _: Ascii |
+ _: Chr | _: Base64 | _: UnBase64 | _: Decode | _: StringDecode | _: Encode | _: ToBinary |
+ _: FormatNumber | _: Sentences | _: StringSplitSQL | _: SplitPart | _: Empty2Null |
+ _: Luhncheck =>
+ true
+ // Datetime
+ case _: TimeZoneAwareExpression =>
+ true
+ // Decimal
+ case _: UnscaledValue | _: MakeDecimal | _: CheckOverflow | _: CheckOverflowInSum |
+ _: DecimalAddNoOverflowCheck |
+ _: DecimalDivideWithOverflowCheck =>
+ true
+ // Interval
+ case _: ExtractIntervalPart[_] | _: IntervalNumOperation | _: MultiplyInterval |
+ _: DivideInterval | _: TryMakeInterval | _: MakeInterval | _: MakeDTInterval |
+ _: MakeYMInterval | _: MultiplyYMInterval | _: MultiplyDTInterval | _: DivideYMInterval |
+ _: DivideDTInterval =>
+ true
+ // Number format
+ case _: ToNumber | _: TryToNumber | _: ToCharacter =>
+ true
+ // Random
+ case _: Rand | _: Randn | _: Uniform | _: RandStr =>
+ true
+ // Regexp
+ case _: Like | _: ILike | _: LikeAll | _: NotLikeAll | _: LikeAny | _: NotLikeAny | _: RLike |
+ _: StringSplit | _: RegExpReplace | _: RegExpExtract | _: RegExpExtractAll |
+ _: RegExpCount | _: RegExpSubStr | _: RegExpInStr =>
+ true
+ // JSON
+ case _: JsonToStructs | _: StructsToJson |
+ _: SchemaOfJson | _: JsonObjectKeys | _: LengthOfJsonArray =>
+ true
+ // CSV
+ case _: SchemaOfCsv | _: StructsToCsv | _: CsvToStructs =>
+ true
+ // URL
+ case _: TryParseUrl | _: ParseUrl | _: UrlEncode | _: UrlDecode | _: TryUrlDecode =>
+ true
+ // XML
+ case _: XmlToStructs | _: SchemaOfXml | _: StructsToXml =>
+ true
+ // Misc
+ case _: TaggingExpression =>
+ true
+ case _ =>
+ false
+ }
+ }
+
+ private def detectUnsupportedConf(): Option[String] = {
+ if (conf.caseSensitiveAnalysis) {
+ Some("caseSensitiveAnalysis")
+ } else if (conf.caseSensitiveInferenceMode != HiveCaseSensitiveInferenceMode.NEVER_INFER) {
+ Some("hiveCaseSensitiveInferenceMode")
+ } else if (conf.getConf(SQLConf.LEGACY_INLINE_CTE_IN_COMMANDS)) {
+ Some("legacyInlineCTEInCommands")
+ } else if (conf.getConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY) !=
+ LegacyBehaviorPolicy.CORRECTED) {
+ Some("legacyCTEPrecedencePolicy")
+ } else {
+ None
+ }
+ }
+
+ private def checkTempVariables() =
+ catalogManager.tempVariableManager.isEmpty
+
+ private def checkScriptingVariables() =
+ SqlScriptingLocalVariableManager.get().forall(_.isEmpty)
+
+ private def tryThrowUnsupportedSinglePassAnalyzerFeature(operator: LogicalPlan): Unit = {
+ tryThrowUnsupportedSinglePassAnalyzerFeature(s"${operator.getClass} operator resolution")
+ }
+
+ private def tryThrowUnsupportedSinglePassAnalyzerFeature(expression: Expression): Unit = {
+ tryThrowUnsupportedSinglePassAnalyzerFeature(s"${expression.getClass} expression resolution")
+ }
+
+ private def tryThrowUnsupportedSinglePassAnalyzerFeature(feature: String): Unit = {
+ if (conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_THROW_FROM_RESOLVER_GUARD)) {
+ throw QueryCompilationErrors.unsupportedSinglePassAnalyzerFeature(feature)
+ }
+ }
}
object ResolverGuard {
@@ -276,16 +460,12 @@ object ResolverGuard {
private val UNSUPPORTED_ATTRIBUTE_NAMES = {
val map = new IdentifierMap[Unit]()
- /**
- * Some SQL functions can be called without the braces and thus they are found in the
- * parsed operator as UnresolvedAttributes. This list contains the names of those functions
- * so we can reject them. Find more information in [[ColumnResolutionHelper.literalFunctions]].
- */
- map += ("current_date", ())
- map += ("current_timestamp", ())
+ // Not supported until we support their ''real'' function counterparts.
map += ("current_user", ())
map += ("user", ())
map += ("session_user", ())
+
+ // Not supported until we support GroupingSets/Cube/Rollup.
map += ("grouping__id", ())
/**
@@ -319,7 +499,6 @@ object ResolverGuard {
// Functions that are not resolved properly.
map += ("collate", ())
map += ("json_tuple", ())
- map += ("schema_of_unstructured_agg", ())
// Functions that produce wrong schemas/plans because of alias assignment.
map += ("from_json", ())
map += ("schema_of_json", ())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverMetricTracker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverMetricTracker.scala
new file mode 100644
index 0000000000000..4d6f4bfdc3934
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverMetricTracker.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.QueryPlanningTracker
+import org.apache.spark.sql.catalyst.rules.QueryExecutionMetering
+
+/**
+ * Trait for tracking and logging timing metrics for single-pass resolver.
+ */
+trait ResolverMetricTracker {
+
+ /**
+ * Log timing metrics for the single-pass analyzer. In order to utilize the existing logging
+ * infrastructure, we log single-pass metrics as if the single-pass resolver were a standalone
+ * Catalyst rule. We log every run of the single-pass resolver as effective.
+ */
+ protected def recordMetrics[R](tracker: QueryPlanningTracker)(body: => R): R =
+ QueryPlanningTracker.withTracker(tracker) {
+ val startTime = System.nanoTime()
+
+ val result = body
+
+ val runTime = System.nanoTime() - startTime
+
+ collectQueryExecutionMetrics(runTime)
+ tracker.recordRuleInvocation(
+ rule = ResolverMetricTracker.SINGLE_PASS_RESOLVER_METRIC_LOGGING_ALIAS,
+ timeNs = runTime,
+ effective = true
+ )
+
+ result
+ }
+
+ private def collectQueryExecutionMetrics(runTime: Long) = {
+ val queryExecutionMetrics = QueryExecutionMetering.INSTANCE
+
+ queryExecutionMetrics.incNumEffectiveExecution(
+ ResolverMetricTracker.SINGLE_PASS_RESOLVER_METRIC_LOGGING_ALIAS
+ )
+ queryExecutionMetrics.incTimeEffectiveExecutionBy(
+ ResolverMetricTracker.SINGLE_PASS_RESOLVER_METRIC_LOGGING_ALIAS,
+ runTime
+ )
+ queryExecutionMetrics.incNumExecution(
+ ResolverMetricTracker.SINGLE_PASS_RESOLVER_METRIC_LOGGING_ALIAS
+ )
+ queryExecutionMetrics.incExecutionTimeBy(
+ ResolverMetricTracker.SINGLE_PASS_RESOLVER_METRIC_LOGGING_ALIAS,
+ runTime
+ )
+ }
+}
+
+object ResolverMetricTracker {
+
+ /**
+ * Name under which single-pass resolver metrics will show up.
+ */
+ val SINGLE_PASS_RESOLVER_METRIC_LOGGING_ALIAS = "SinglePassResolver"
+}
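A minimal sketch of how the trait above is intended to be mixed in (the `TimedResolutionStep` class is hypothetical; `ResolverMetricTracker` and `QueryPlanningTracker` come from the code above):

import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.analysis.resolver.ResolverMetricTracker

class TimedResolutionStep extends ResolverMetricTracker {
  // recordMetrics times `work`, bumps the QueryExecutionMetering counters and records
  // one effective "SinglePassResolver" rule invocation on `tracker`.
  def run[T](tracker: QueryPlanningTracker)(work: => T): T =
    recordMetrics(tracker) {
      work
    }
}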
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala
index 99f7409d31fa3..16f3a9f1444fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala
@@ -17,58 +17,65 @@
package org.apache.spark.sql.catalyst.analysis.resolver
-import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.{QueryPlanningTracker, SQLConfHelper}
import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, CleanupAliases}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
/**
* Wrapper class for [[Resolver]] and single-pass resolution. This class encapsulates single-pass
- * resolution and post-processing of resolved plan. This post-processing is necessary in order to
+ * resolution, rewriting and validation of the resolved plan. The plan rewrite is necessary in order to
* either fully resolve the plan or stay compatible with the fixed-point analyzer.
*/
class ResolverRunner(
resolver: Resolver,
extendedResolutionChecks: Seq[LogicalPlan => Unit] = Seq.empty
-) extends SQLConfHelper {
+) extends ResolverMetricTracker
+ with SQLConfHelper {
- private val resolutionPostProcessingExecutor = new RuleExecutor[LogicalPlan] {
- override def batches: Seq[Batch] = Seq(
- Batch("Post-process", Once, CleanupAliases)
- )
- }
+ /**
+ * Sequence of post-resolution rules that should be applied on the result of single-pass
+ * resolution.
+ */
+ private val planRewriteRules: Seq[Rule[LogicalPlan]] = Seq(
+ PruneMetadataColumns,
+ CleanupAliases
+ )
/**
- * Entry point for the resolver. This method performs following 3 steps:
- * - Resolves the plan in a bottom-up, single-pass manner.
- * - Validates the result of single-pass resolution.
- * - Applies necessary post-processing rules.
+ * `planRewriter` is used to rewrite the plan and the subqueries inside by applying
+ * `planRewriteRules`.
*/
- def resolve(
- plan: LogicalPlan,
- analyzerBridgeState: Option[AnalyzerBridgeState] = None): LogicalPlan = {
- AnalysisContext.withNewAnalysisContext {
- val resolvedPlan = resolver.lookupMetadataAndResolve(plan, analyzerBridgeState)
- if (conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_VALIDATION_ENABLED)) {
- val validator = new ResolutionValidator
- validator.validatePlan(resolvedPlan)
- }
- finishResolution(resolvedPlan)
- }
- }
+ private val planRewriter = new PlanRewriter(planRewriteRules)
/**
- * This method performs necessary post-processing rules that aren't suitable for single-pass
- * resolver. We apply these rules after the single-pass has finished resolution to stay
- * compatible with fixed-point analyzer.
+ * Entry point for the resolver. This method performs the following 4 steps:
+ * - Resolves the plan in a bottom-up, single-pass manner using [[Resolver]].
+ * - Rewrites the plan using rules configured in the [[planRewriter]].
+ * - Validates the final result internally using [[ResolutionValidator]].
+ * - Validates the final result using [[extendedResolutionChecks]].
*/
- private def finishResolution(plan: LogicalPlan): LogicalPlan = {
- val planWithPostProcessing = resolutionPostProcessingExecutor.execute(plan)
+ def resolve(
+ plan: LogicalPlan,
+ analyzerBridgeState: Option[AnalyzerBridgeState] = None,
+ tracker: QueryPlanningTracker = new QueryPlanningTracker): LogicalPlan =
+ recordMetrics(tracker) {
+ AnalysisContext.withNewAnalysisContext {
+ val resolvedPlan = resolver.lookupMetadataAndResolve(plan, analyzerBridgeState)
+
+ val rewrittenPlan = planRewriter.rewriteWithSubqueries(resolvedPlan)
- for (rule <- extendedResolutionChecks) {
- rule(planWithPostProcessing)
+ if (conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_VALIDATION_ENABLED)) {
+ val validator = new ResolutionValidator
+ validator.validatePlan(rewrittenPlan)
+ }
+
+ for (rule <- extendedResolutionChecks) {
+ rule(rewrittenPlan)
+ }
+
+ rewrittenPlan
+ }
}
- planWithPostProcessing
- }
}
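A minimal invocation sketch for the reworked entry point (the driver object is hypothetical; it assumes a `Resolver` instance and an unresolved `parsedPlan` are already available):

import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.analysis.resolver.{Resolver, ResolverRunner}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

object ResolverRunnerUsageSketch {
  def analyzeSinglePass(resolver: Resolver, parsedPlan: LogicalPlan): LogicalPlan = {
    val runner = new ResolverRunner(resolver)
    // resolve() performs single-pass resolution, rewrites the plan with
    // PruneMetadataColumns and CleanupAliases, validates the result, and records
    // timing metrics on the supplied tracker.
    runner.resolve(plan = parsedPlan, tracker = new QueryPlanningTracker)
  }
}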
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesExpressionChildren.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesExpressionChildren.scala
index c170941ce5348..287164b80bbc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesExpressionChildren.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesExpressionChildren.scala
@@ -17,17 +17,96 @@
package org.apache.spark.sql.catalyst.analysis.resolver
-import org.apache.spark.sql.catalyst.expressions.Expression
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{
+ BinaryExpression,
+ Expression,
+ QuaternaryExpression,
+ TernaryExpression,
+ UnaryExpression
+}
trait ResolvesExpressionChildren {
+ /**
+ * Resolves [[UnaryExpression]] children and returns its copy with children resolved.
+ */
+ protected def withResolvedChildren(
+ unresolvedExpression: UnaryExpression,
+ resolveChild: Expression => Expression): Expression = {
+ val newChildren = Seq(resolveChild(unresolvedExpression.child))
+ unresolvedExpression.withNewChildren(newChildren)
+ }
+
+ /**
+ * Resolves [[BinaryExpression]] children and returns its copy with children resolved.
+ */
+ protected def withResolvedChildren(
+ unresolvedExpression: BinaryExpression,
+ resolveChild: Expression => Expression): Expression = {
+ val newChildren =
+ Seq(resolveChild(unresolvedExpression.left), resolveChild(unresolvedExpression.right))
+ unresolvedExpression.withNewChildren(newChildren)
+ }
+
+ /**
+ * Resolves [[TernaryExpression]] children and returns its copy with children resolved.
+ */
+ protected def withResolvedChildren(
+ unresolvedExpression: TernaryExpression,
+ resolveChild: Expression => Expression): Expression = {
+ val newChildren = Seq(
+ resolveChild(unresolvedExpression.first),
+ resolveChild(unresolvedExpression.second),
+ resolveChild(unresolvedExpression.third)
+ )
+ unresolvedExpression.withNewChildren(newChildren)
+ }
+
+ /**
+ * Resolves [[QuaternaryExpression]] children and returns its copy with children resolved.
+ */
+ protected def withResolvedChildren(
+ unresolvedExpression: QuaternaryExpression,
+ resolveChild: Expression => Expression): Expression = {
+ val newChildren = Seq(
+ resolveChild(unresolvedExpression.first),
+ resolveChild(unresolvedExpression.second),
+ resolveChild(unresolvedExpression.third),
+ resolveChild(unresolvedExpression.fourth)
+ )
+ unresolvedExpression.withNewChildren(newChildren)
+ }
+
/**
* Resolves generic [[Expression]] children and returns its copy with children resolved.
*/
- protected def withResolvedChildren[ExpressionType <: Expression](
- unresolvedExpression: ExpressionType,
- resolveChild: Expression => Expression): ExpressionType = {
- val newChildren = unresolvedExpression.children.map(resolveChild(_))
- unresolvedExpression.withNewChildren(newChildren).asInstanceOf[ExpressionType]
+ protected def withResolvedChildren(
+ unresolvedExpression: Expression,
+ resolveChild: Expression => Expression): Expression = unresolvedExpression match {
+ case unaryExpression: UnaryExpression =>
+ withResolvedChildren(unaryExpression, resolveChild)
+ case binaryExpression: BinaryExpression =>
+ withResolvedChildren(binaryExpression, resolveChild)
+ case ternaryExpression: TernaryExpression =>
+ withResolvedChildren(ternaryExpression, resolveChild)
+ case quaternaryExpression: QuaternaryExpression =>
+ withResolvedChildren(quaternaryExpression, resolveChild)
+ case _ =>
+ withResolvedChildrenImpl(unresolvedExpression, resolveChild)
+ }
+
+ private def withResolvedChildrenImpl(
+ unresolvedExpression: Expression,
+ resolveChild: Expression => Expression): Expression = {
+ val newChildren = new mutable.ArrayBuffer[Expression](unresolvedExpression.children.size)
+
+ val childrenIterator = unresolvedExpression.children.iterator
+ while (childrenIterator.hasNext) {
+ newChildren += resolveChild(childrenIterator.next())
+ }
+
+ unresolvedExpression.withNewChildren(newChildren.toSeq)
}
}
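A small, illustrative mix-in sketch (the `ChildResolutionSketch` object is hypothetical): since the overloads are resolved statically, passing a `Not` picks the [[UnaryExpression]] overload, so only the single child is resolved and re-attached:

import org.apache.spark.sql.catalyst.analysis.resolver.ResolvesExpressionChildren
import org.apache.spark.sql.catalyst.expressions.{Expression, Not}

object ChildResolutionSketch extends ResolvesExpressionChildren {
  // The static type of `not` is a UnaryExpression, so the single-child overload is
  // chosen at compile time and the generic children traversal is avoided.
  def resolveNotChildren(not: Not, resolveChild: Expression => Expression): Expression =
    withResolvedChildren(not, resolveChild)
}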
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesNameByHiddenOutput.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesNameByHiddenOutput.scala
new file mode 100644
index 0000000000000..3d9f964a7b897
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesNameByHiddenOutput.scala
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.util._
+
+/**
+ * [[ResolvesNameByHiddenOutput]] is used by resolvers for operators that can resolve attributes
+ * in their expression trees from the hidden output, or that can reference expressions not
+ * present in the child's output. The trait updates the child operator's output list and places a
+ * [[Project]] node on top of the original operator node, with the original output of the
+ * operator's child.
+ *
+ * For example, in the following query:
+ *
+ * {{{
+ * SELECT
+ * t1.key
+ * FROM
+ * t1 FULL OUTER JOIN t2 USING (key)
+ * WHERE
+ * t1.key NOT LIKE 'bb.%';
+ * }}}
+ *
+ * Plan without adding missing attributes would be:
+ *
+ * {{{
+ * +- Project [key#1]
+ * +- Filter NOT key#1 LIKE bb.%
+ * +- Project [coalesce(key#1, key#2) AS key#3, __key#1__, __key#2__]
+ * +- Join FullOuter, (key#1 = key#2)
+ * :- SubqueryAlias t1
+ * : +- Relation t1[key#1]
+ * +- SubqueryAlias t2
+ * +- Relation t2[key#2]
+ * }}}
+ *
+ * NOTE: __key#1__ and __key#2__ at the end of the inner [[Project]] are metadata columns from the
+ * full outer join. Even though they are present in the [[Project]] in single-pass, fixed-point
+ * adds these columns after resolving missing input, so duplication of some columns is possible.
+ * In order to stay fully compatible between single-pass and fixed-point, we add both missing
+ * attributes and these metadata columns. We mimic fixed-point behavior by putting metadata
+ * columns in [[NameScope.hiddenOutput]] instead of [[NameScope.output]].
+ *
+ * In the plan above, [[Filter]] requires key#1 in its condition, but key#1 is __not__ available
+ * in the output of the [[Project]] below it, even though key#1 is available in the [[Join]]'s
+ * hidden output. Because of that, we need to place key#1 in the project list, after the original
+ * project list expressions, but before the metadata columns (to remain compatible with
+ * fixed-point). In order to preserve the initial output of [[Filter]], we place a [[Project]]
+ * node on top of this [[Filter]], whose project list is the original output of the [[Project]]
+ * __below__ [[Filter]] (in this case, key#3 and the metadata columns key#1 and key#2).
+ *
+ * Therefore, the plan becomes:
+ *
+ * {{{
+ * +- Project [key#1]
+ * +- Project [key#3, key#1, key#2]
+ * +- Filter NOT key#1 LIKE bb.%
+ * +- Project [coalesce(key#1, key#2) AS key#3, key#1, key#1, key#2]
+ * +- Join FullOuter, (key#1 = key#2)
+ * :- SubqueryAlias t1
+ * : +- Relation t1[key#1]
+ * +- SubqueryAlias t2
+ * +- Relation t2[key#2]
+ * }}}
+ *
+ * The query below exhibits similar behavior, where the [[Sort]] operator resolves an attribute
+ * using the hidden output:
+ *
+ * {{{ SELECT col1 FROM VALUES (1, 2) ORDER BY col2; }}}
+ *
+ * Unresolved plan would be:
+ *
+ * {{{
+ * Sort [col2 ASC NULLS FIRST], true
+ * +- Project [col1]
+ * +- LocalRelation [col1, col2]
+ * }}}
+ *
+ * As can be seen, the attribute `col2` used in [[Sort]] can't be resolved using the [[Project]]
+ * output (which is [`col1`]), so it has to be resolved using the hidden output (which is
+ * propagated from [[LocalRelation]] and is [`col1`, `col2`]). As shown in the previous example,
+ * `col2` has to be added to the [[Project]] list, and a [[Project]] with the original output of
+ * the [[Project]] below [[Sort]] is added as the top node. Because of that, the analyzed plan is:
+ *
+ * {{{
+ * Project [col1]
+ * +- Sort [col2 ASC NULLS FIRST], true
+ * +- Project [col1, col2]
+ * +- LocalRelation [col1, col2]
+ * }}}
+ *
+ * Another example is when the [[Sort]] order expression is an [[AggregateExpression]] that is
+ * not present in [[Aggregate.aggregateExpressions]]:
+ *
+ * {{{
+ * SELECT col1 FROM VALUES (1) GROUP BY col1 ORDER BY sum(col1);
+ * }}}
+ *
+ * In this example `sum(col1)` should be added to the child's output, and a [[Project]] node
+ * should be added on top of the [[Sort]] node to preserve the original output of the
+ * [[Aggregate]] node:
+ *
+ * {{{
+ * Project [col1]
+ * +- Sort [sum(col1)#... ASC NULLS FIRST], true
+ * +- Aggregate [col1], [col1, sum(col1) AS sum(col1)#...]
+ * +- LocalRelation [col1]
+ * }}}
+ */
+trait ResolvesNameByHiddenOutput {
+ protected val scopes: NameScopeStack
+
+ /**
+ * If the child of an operator is a [[Project]] or an [[Aggregate]] and that operator has missing
+ * expressions, insert the missing expressions in the output list of the operator.
+ * In order to stay compatible with fixed-point, missing expressions are inserted after the
+ * original output list, but before any qualified access only columns that have been added as
+ * part of resolution from hidden output.
+ */
+ def insertMissingExpressions(
+ operator: LogicalPlan,
+ missingExpressions: Seq[NamedExpression]): LogicalPlan =
+ operator match {
+ case operator @ (_: Project | _: Aggregate) if missingExpressions.nonEmpty =>
+ expandOperatorsOutputList(operator, missingExpressions)
+ case other => other
+ }
+
+ private def expandOperatorsOutputList(
+ operator: LogicalPlan,
+ missingExpressions: Seq[NamedExpression]): LogicalPlan = {
+ val (metadataCols, nonMetadataCols) = operator match {
+ case project: Project =>
+ project.projectList.partition(_.toAttribute.qualifiedAccessOnly)
+ case aggregate: Aggregate =>
+ aggregate.aggregateExpressions.partition(_.toAttribute.qualifiedAccessOnly)
+ }
+
+ val newOutputList = nonMetadataCols ++ missingExpressions ++ metadataCols
+ val newOperator = operator match {
+ case aggregate: Aggregate =>
+ aggregate.copy(aggregateExpressions = newOutputList)
+ case project: Project =>
+ project.copy(projectList = newOutputList)
+ }
+
+ newOperator
+ }
+
+ /**
+ * If [[missingExpressions]] is not empty, the output of the operator has been changed by
+ * [[insertMissingExpressions]]. Therefore, we need to restore the original output by placing a
+ * [[Project]] on top of the original node, with the original node's output. Additionally, we
+ * append all qualified-access-only columns from the hidden output, because they may be needed in
+ * upper operators (if not, they will be pruned away in [[PruneMetadataColumns]]).
+ */
+ def retainOriginalOutput(
+ operator: LogicalPlan,
+ missingExpressions: Seq[NamedExpression]): LogicalPlan = {
+ if (missingExpressions.isEmpty) {
+ operator
+ } else {
+ val project = Project(
+ scopes.current.output.map(_.asInstanceOf[NamedExpression]) ++ scopes.current.hiddenOutput
+ .filter(_.qualifiedAccessOnly),
+ operator
+ )
+ scopes.overwriteCurrent(output = Some(project.projectList.map(_.toAttribute)))
+ project
+ }
+ }
+}
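An illustrative sketch of the intended call pattern (the class below is hypothetical and is not the actual Sort/Filter resolver; it only assumes `NameScopeStack` is available from this package): widen the child's output list with the expressions resolved from the hidden output, then restore the operator's original output with a [[Project]] on top:

import org.apache.spark.sql.catalyst.analysis.resolver.{NameScopeStack, ResolvesNameByHiddenOutput}
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}

class HiddenOutputResolutionSketch(protected val scopes: NameScopeStack)
  extends ResolvesNameByHiddenOutput {

  def finishSortResolution(
      sort: Sort,
      resolvedChild: LogicalPlan,
      missingExpressions: Seq[NamedExpression]): LogicalPlan = {
    // Insert the missing expressions into the child's output list (only has an effect
    // for Project/Aggregate children) ...
    val childWithMissingExpressions = insertMissingExpressions(resolvedChild, missingExpressions)
    val sortWithNewChild = sort.copy(child = childWithMissingExpressions)
    // ... and put a Project on top so the original output is preserved for parent operators.
    retainOriginalOutput(sortWithNewChild, missingExpressions)
  }
}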
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SemanticComparator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SemanticComparator.scala
new file mode 100644
index 0000000000000..6ec34cf8511e9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SemanticComparator.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import java.util.{ArrayList, HashMap}
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+/**
+ * [[SemanticComparator]] is a tool to compare expressions semantically to a predefined sequence
+ * of `targetExpressions`. Semantic comparison is based on [[QueryPlan.canonicalized]] - for
+ * example, `col1 + 1 + col2` is semantically equal to `1 + col2 + col1`. To speed up slow tree
+ * traversals and expression node field comparisons, we cache the semantic hashes (a semantic
+ * hash is simply a hash of the canonicalized subtree) and use them for O(1) indexing. If they don't
+ * match, we perform an early return. Otherwise, we invoke the heavy [[Expression.semanticEquals]]
+ * method to make sure that expression trees are indeed identical.
+ */
+class SemanticComparator(targetExpressions: Seq[Expression]) {
+ private val targetExpressionsBySemanticHash =
+ new HashMap[Int, ArrayList[Expression]](targetExpressions.size)
+
+ for (targetExpression <- targetExpressions) {
+ targetExpressionsBySemanticHash
+ .computeIfAbsent(targetExpression.semanticHash(), _ => new ArrayList[Expression])
+ .add(targetExpression)
+ }
+
+ /**
+ * Returns the first expression in `targetExpressions` that is semantically equal to the given
+ * `expression`. If no such expression is found, returns `None`.
+ */
+ def collectFirst(expression: Expression): Option[Expression] = {
+ targetExpressionsBySemanticHash.get(expression.semanticHash()) match {
+ case null =>
+ None
+ case targetExpressions =>
+ val iter = targetExpressions.iterator
+ var matchedExpression: Option[Expression] = None
+ while (iter.hasNext && matchedExpression.isEmpty) {
+ val element = iter.next
+ if (element.semanticEquals(expression)) {
+ matchedExpression = Some(element)
+ }
+ }
+ matchedExpression
+ }
+ }
+
+ /**
+ * Use the previously constructed `targetExpressionsBySemanticHash` to check if the given
+ * `expression` is semantically equal to any of the target expressions.
+ */
+ def exists(expression: Expression): Boolean = {
+ collectFirst(expression).nonEmpty
+ }
+}
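A short usage example of the comparator above (illustrative only; the attributes and expressions are constructed ad hoc):

import org.apache.spark.sql.catalyst.analysis.resolver.SemanticComparator
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal}
import org.apache.spark.sql.types.IntegerType

object SemanticComparatorUsageSketch {
  def main(args: Array[String]): Unit = {
    val col1 = AttributeReference("col1", IntegerType)()
    val col2 = AttributeReference("col2", IntegerType)()

    // Index `col1 + 1 + col2` by its semantic hash.
    val comparator = new SemanticComparator(Seq(Add(Add(col1, Literal(1)), col2)))

    // `1 + col2 + col1` canonicalizes to the same tree, so the lookup succeeds.
    assert(comparator.exists(Add(Add(Literal(1), col2), col1)))

    // A different expression over the same attributes does not match.
    assert(comparator.collectFirst(Add(col1, col2)).isEmpty)
  }
}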
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SetOperationLikeResolver.scala
similarity index 52%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SetOperationLikeResolver.scala
index 0e4eed3c20f15..216cdc679fa92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SetOperationLikeResolver.scala
@@ -19,26 +19,32 @@ package org.apache.spark.sql.catalyst.analysis.resolver
import java.util.HashSet
-import org.apache.spark.sql.catalyst.analysis.{
- withPosition,
- AnsiTypeCoercion,
- TypeCoercion,
- TypeCoercionBase
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion, TypeCoercionBase}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, ExprId}
+import org.apache.spark.sql.catalyst.plans.logical.{
+ Except,
+ Intersect,
+ LogicalPlan,
+ Project,
+ SetOperation,
+ Union
}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, ExprId}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Project, Union}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{DataType, MetadataBuilder}
+import org.apache.spark.sql.types.{DataType, MapType, MetadataBuilder, VariantType}
/**
- * The [[UnionResolver]] performs [[Union]] operator resolution. This operator has 2+
- * children. Resolution involves checking and normalizing child output attributes
- * (data types and nullability).
+ * The [[SetOperationLikeResolver]] performs [[Union]], [[Intersect]] or [[Except]] operator
+ * resolution. These operators have 2+ children. Resolution involves checking and normalizing child
+ * output attributes (data types and nullability).
*/
-class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
- extends TreeNodeResolver[Union, Union] {
+class SetOperationLikeResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
+ extends TreeNodeResolver[LogicalPlan, LogicalPlan] {
private val expressionIdAssigner = expressionResolver.getExpressionIdAssigner
+ private val autoGeneratedAliasProvider = new AutoGeneratedAliasProvider(expressionIdAssigner)
private val scopes = resolver.getNameScopes
+ private val cteRegistry = resolver.getCteRegistry
private val typeCoercion: TypeCoercionBase =
if (conf.ansiEnabled) {
AnsiTypeCoercion
@@ -47,13 +53,9 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
}
/**
- * Resolve the [[Union]] operator:
- * - Retrieve old output and child outputs if the operator is already resolved. This is relevant
- * for partially resolved subtrees from DataFrame programs.
+ * Resolve the [[Union]], [[Intersect]] or [[Except]] operators:
* - Resolve each child in the context of a) New [[NameScope]] b) New [[ExpressionIdAssigner]]
- * mapping. Collect child outputs to coerce them later.
- * - Perform projection-based expression ID deduplication if required. This is a hack to stay
- * compatible with fixed-point [[Analyzer]].
+ * mapping c) CTE scope. Collect child outputs to coerce them later.
* - Perform individual output deduplication to handle the distinct union case described in
* [[performIndividualOutputExpressionIdDeduplication]] scaladoc.
* - Validate that child outputs have same length or throw "NUM_COLUMNS_MISMATCH" otherwise.
@@ -61,118 +63,74 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
* [[typeCoercion.findWiderTypeForTwo]] or throw "INCOMPATIBLE_COLUMN_TYPE" if coercion fails.
* - Add [[Project]] with [[Cast]] on children needing attribute data type widening.
* - Assert that coerced outputs don't have conflicting expression IDs.
- * - Merge transformed outputs: For each column, merge child attributes' types using
- * [[StructType.unionLikeMerge]]. Mark column as nullable if any child attribute is.
+ * - Merge transformed outputs using separate logic for each operator type.
* - Store merged output in current [[NameScope]].
* - Create a new mapping in [[ExpressionIdAssigner]] using the coerced and validated outputs.
- * - Return the resolved [[Union]] with new children.
+ * - Return the resolved operator with new children optionally wrapped in [[WithCTE]]. See
+ * [[CteScope]] scaladoc for more info.
*/
- override def resolve(unresolvedUnion: Union): Union = {
- val (oldOutput, oldChildOutputs) = if (unresolvedUnion.resolved) {
- (Some(unresolvedUnion.output), Some(unresolvedUnion.children.map(_.output)))
- } else {
- (None, None)
- }
+ override def resolve(unresolvedOperator: LogicalPlan): LogicalPlan = {
+ val (resolvedChildren, childOutputs) = resolveChildren(unresolvedOperator)
- val (resolvedChildren, childOutputs) = unresolvedUnion.children.zipWithIndex.map {
- case (unresolvedChild, childIndex) =>
- scopes.withNewScope {
- expressionIdAssigner.withNewMapping(isLeftmostChild = (childIndex == 0)) {
- val resolvedChild = resolver.resolve(unresolvedChild)
- (resolvedChild, scopes.top.output)
- }
- }
- }.unzip
+ expressionIdAssigner.createMappingFromChildMappings()
- val (projectBasedDeduplicatedChildren, projectBasedDeduplicatedChildOutputs) =
- performProjectionBasedExpressionIdDeduplication(
- resolvedChildren,
- childOutputs,
- oldChildOutputs
- )
val (deduplicatedChildren, deduplicatedChildOutputs) =
performIndividualOutputExpressionIdDeduplication(
- projectBasedDeduplicatedChildren,
- projectBasedDeduplicatedChildOutputs
+ resolvedChildren,
+ childOutputs,
+ unresolvedOperator
)
- val (newChildren, newChildOutputs) = if (needToCoerceChildOutputs(deduplicatedChildOutputs)) {
- coerceChildOutputs(
- deduplicatedChildren,
- deduplicatedChildOutputs,
- validateAndDeduceTypes(unresolvedUnion, deduplicatedChildOutputs)
- )
- } else {
- (deduplicatedChildren, deduplicatedChildOutputs)
- }
+ val (newChildren, newChildOutputs) =
+ if (needToCoerceChildOutputs(deduplicatedChildOutputs, unresolvedOperator)) {
+ coerceChildOutputs(
+ deduplicatedChildren,
+ deduplicatedChildOutputs,
+ validateAndDeduceTypes(unresolvedOperator, deduplicatedChildOutputs)
+ )
+ } else {
+ (deduplicatedChildren, deduplicatedChildOutputs)
+ }
ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds(newChildOutputs)
- withPosition(unresolvedUnion) {
- scopes.overwriteTop(Union.mergeChildOutputs(newChildOutputs))
- }
+ val output = mergeChildOutputs(unresolvedOperator, newChildOutputs)
+ scopes.overwriteCurrent(output = Some(output), hiddenOutput = Some(output))
- expressionIdAssigner.createMapping(scopes.top.output, oldOutput)
+ validateOutputs(unresolvedOperator, output)
- unresolvedUnion.copy(children = newChildren)
+ val resolvedOperator = unresolvedOperator.withNewChildren(newChildren)
+
+ cteRegistry.currentScope.tryPutWithCTE(
+ unresolvedOperator = unresolvedOperator,
+ resolvedOperator = resolvedOperator
+ )
}
/**
- * Fixed-point [[Analyzer]] uses [[DeduplicateRelations]] rule to handle duplicate expression IDs
- * in multi-child operator outputs. For [[Union]]s it uses a "projection-based deduplication",
- * i.e. places another [[Project]] operator with new [[Alias]]es on the right child if duplicate
- * expression IDs detected. New [[Alias]] "covers" the original attribute with new expression ID.
- * This is done for all child operators except [[LeafNode]]s.
+ * Resolve `unresolvedOperator`'s children in the context of a new [[NameScope]],
+ * [[ExpressionIdAssigner]] mapping and [[CteScope]].
*
- * We don't need this operation in single-pass [[Resolver]], since we have
- * [[ExpressionIdAssigner]] for expression ID deduplication, but perform it nevertheless to stay
- * compatible with fixed-point [[Analyzer]]. Since new outputs are already deduplicated by
- * [[ExpressionIdAssigner]], we check the _old_ outputs for duplicates and place a [[Project]]
- * only if old outputs are available (i.e. we are dealing with a resolved subtree from
- * DataFrame program).
+ * [[ExpressionIdAssigner]] child mapping is collected just for the left child, because that's
+ * the only child whose expression IDs get propagated upwards through [[Union]], [[Intersect]] or
+ * [[Except]]. This is an optimization to avoid fast-growing expression ID mappings.
*/
- private def performProjectionBasedExpressionIdDeduplication(
- children: Seq[LogicalPlan],
- childOutputs: Seq[Seq[Attribute]],
- oldChildOutputs: Option[Seq[Seq[Attribute]]]
- ): (Seq[LogicalPlan], Seq[Seq[Attribute]]) = {
- oldChildOutputs match {
- case Some(oldChildOutputs) =>
- val oldExpressionIds = new HashSet[ExprId]
-
- children
- .zip(childOutputs)
- .zip(oldChildOutputs)
- .map {
- case ((child: LeafNode, output), _) =>
- (child, output)
-
- case ((child, output), oldOutput) =>
- val oldOutputExpressionIds = new HashSet[ExprId]
-
- val hasConflicting = oldOutput.exists { oldAttribute =>
- oldOutputExpressionIds.add(oldAttribute.exprId)
- oldExpressionIds.contains(oldAttribute.exprId)
- }
-
- if (hasConflicting) {
- val newExpressions = output.map { attribute =>
- Alias(attribute, attribute.name)()
- }
- (
- Project(projectList = newExpressions, child = child),
- newExpressions.map(_.toAttribute)
- )
- } else {
- oldExpressionIds.addAll(oldOutputExpressionIds)
-
- (child, output)
- }
+ private def resolveChildren(
+ unresolvedOperator: LogicalPlan): (Seq[LogicalPlan], Seq[Seq[Attribute]]) = {
+ unresolvedOperator.children.zipWithIndex.map {
+ case (unresolvedChild, childIndex) =>
+ scopes.withNewScope() {
+ expressionIdAssigner.withNewMapping(collectChildMapping = childIndex == 0) {
+ cteRegistry.withNewScopeUnderMultiChildOperator(
+ unresolvedOperator = unresolvedOperator,
+ unresolvedChild = unresolvedChild
+ ) {
+ val resolvedChild = resolver.resolve(unresolvedChild)
+ (resolvedChild, scopes.current.output)
+ }
}
- .unzip
- case _ =>
- (children, childOutputs)
- }
+ }
+ }.unzip
}
/**
@@ -204,6 +162,17 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
* See SPARK-37865 for more details.
*/
private def performIndividualOutputExpressionIdDeduplication(
+ children: Seq[LogicalPlan],
+ childOutputs: Seq[Seq[Attribute]],
+ unresolvedOperator: LogicalPlan
+ ): (Seq[LogicalPlan], Seq[Seq[Attribute]]) = {
+ unresolvedOperator match {
+ case _: Union => doPerformIndividualOutputExpressionIdDeduplication(children, childOutputs)
+ case _ => (children, childOutputs)
+ }
+ }
+
+ private def doPerformIndividualOutputExpressionIdDeduplication(
children: Seq[LogicalPlan],
childOutputs: Seq[Seq[Attribute]]
): (Seq[LogicalPlan], Seq[Seq[Attribute]]) = {
@@ -222,7 +191,11 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
.withMetadata(attribute.metadata)
.putNull("__is_duplicate")
.build()
- Alias(attribute, attribute.name)(explicitMetadata = Some(newMetadata))
+ autoGeneratedAliasProvider.newAlias(
+ child = attribute,
+ name = Some(attribute.name),
+ explicitMetadata = Some(newMetadata)
+ )
} else {
expressionIds.add(attribute.exprId)
@@ -244,31 +217,54 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
* - Output length differs between children. We will throw an appropriate error later during type
* coercion with more diagnostics.
* - Output data types differ between children. We don't care about nullability for type coercion,
- * it will be correctly assigned later by [[Union.mergeChildOutputs]].
+ * it will be correctly assigned later by [[SetOperationLikeResolver.mergeChildOutputs]].
*/
- private def needToCoerceChildOutputs(childOutputs: Seq[Seq[Attribute]]): Boolean = {
+ private def needToCoerceChildOutputs(
+ childOutputs: Seq[Seq[Attribute]],
+ unresolvedOperator: LogicalPlan): Boolean = {
val firstChildOutput = childOutputs.head
childOutputs.tail.exists { childOutput =>
childOutput.length != firstChildOutput.length ||
childOutput.zip(firstChildOutput).exists {
case (lhsAttribute, rhsAttribute) =>
- !DataType.equalsStructurally(
+ !areDataTypesCompatibleInTheContextOfOperator(
+ unresolvedOperator,
lhsAttribute.dataType,
- rhsAttribute.dataType,
- ignoreNullability = true
+ rhsAttribute.dataType
)
}
}
}
+ /**
+ * This method returns whether types are compatible in the context of the specified operator.
+ *
+ * In fixed-point we only use [[DataType.equalsStructurally]] for [[Union]] type coercion. For
+ * [[Except]] and [[Intersect]] we use [[DataTypeUtils.sameType]]. This method ensures that the
+ * check for whether coercion is needed is performed in a way that is compatible with fixed-point.
+ */
+ private def areDataTypesCompatibleInTheContextOfOperator(
+ unresolvedPlan: LogicalPlan,
+ lhs: DataType,
+ rhs: DataType): Boolean = {
+ unresolvedPlan match {
+ case _: Union => DataType.equalsStructurally(lhs, rhs, ignoreNullability = true)
+ case _: Except | _: Intersect => DataTypeUtils.sameType(lhs, rhs)
+ case other =>
+ throw SparkException.internalError(
+ s"Set operation resolver should not be used for ${other.nodeName}"
+ )
+ }
+ }
+
/**
* Returns a sequence of data types representing the widened data types for each column:
- * - Validates that the number of columns in each child of the `Union` operator are equal.
+ * - Validates that the number of columns in each child of the set operator is equal.
* - Validates that the data types of columns can be widened to a common type.
* - Deduces the widened data types for each column.
*/
private def validateAndDeduceTypes(
- unresolvedUnion: Union,
+ unresolvedOperator: LogicalPlan,
childOutputs: Seq[Seq[Attribute]]): Seq[DataType] = {
val childDataTypes = childOutputs.map(attributes => attributes.map(attr => attr.dataType))
@@ -281,7 +277,7 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
expectedNumColumns,
childColumnTypes,
childIndex,
- unresolvedUnion
+ unresolvedOperator
)
}
@@ -289,7 +285,7 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
case ((widenedColumnType, columnTypeForCurrentRow), columnIndex) =>
typeCoercion.findWiderTypeForTwo(widenedColumnType, columnTypeForCurrentRow).getOrElse {
throwIncompatibleColumnTypeError(
- unresolvedUnion,
+ unresolvedOperator,
columnIndex,
childIndex + 1,
widenedColumnType,
@@ -325,10 +321,10 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
attribute
} else {
outputChanged = true
- Alias(
- Cast(attribute, widenedType, Some(conf.sessionLocalTimeZone)),
- attribute.name
- )()
+ autoGeneratedAliasProvider.newAlias(
+ child = Cast(attribute, widenedType, Some(conf.sessionLocalTimeZone)),
+ name = Some(attribute.name)
+ )
}
}
@@ -341,34 +337,95 @@ class UnionResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
.unzip
}
+ /**
+ * Helper method to call appropriate object method [[mergeChildOutputs]] for each operator.
+ */
+ private def mergeChildOutputs(
+ unresolvedPlan: LogicalPlan,
+ childOutputs: Seq[Seq[Attribute]]): Seq[Attribute] = {
+ unresolvedPlan match {
+ case _: Union => Union.mergeChildOutputs(childOutputs)
+ case _: Except => Except.mergeChildOutputs(childOutputs)
+ case _: Intersect => Intersect.mergeChildOutputs(childOutputs)
+ case other =>
+ throw SparkException.internalError(
+ s"Set operation resolver should not be used for ${other.nodeName}"
+ )
+ }
+ }
+
+ /**
+ * Validate outputs of [[SetOperation]].
+ * - [[MapType]] and [[VariantType]] are currently not supported in [[SetOperation]]s, so we need
+ * to throw a relevant user-facing error.
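+ *
+ * The check is recursive (see [[hasMapType]] and [[hasVariantType]] below), so, for example, a
+ * column of type `ARRAY<MAP<STRING, INT>>` is also rejected:
+ * {{{
+ *   hasMapType(ArrayType(MapType(StringType, IntegerType))) // true
+ * }}}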
+ */
+ private def validateOutputs(unresolvedPlan: LogicalPlan, output: Seq[Attribute]): Unit = {
+ unresolvedPlan match {
+ case _: SetOperation =>
+ output.find(a => hasMapType(a.dataType)).foreach { mapCol =>
+ throwUnsupportedSetOperationOnMapType(mapCol, unresolvedPlan)
+ }
+ output.find(a => hasVariantType(a.dataType)).foreach { variantCol =>
+ throwUnsupportedSetOperationOnVariantType(variantCol, unresolvedPlan)
+ }
+ case _ =>
+ }
+ }
+
+ private def throwUnsupportedSetOperationOnMapType(
+ mapCol: Attribute,
+ unresolvedPlan: LogicalPlan): Unit = {
+ throw QueryCompilationErrors.unsupportedSetOperationOnMapType(
+ mapCol = mapCol,
+ origin = unresolvedPlan.origin
+ )
+ }
+
+ private def throwUnsupportedSetOperationOnVariantType(
+ variantCol: Attribute,
+ unresolvedPlan: LogicalPlan): Unit = {
+ throw QueryCompilationErrors.unsupportedSetOperationOnVariantType(
+ variantCol = variantCol,
+ origin = unresolvedPlan.origin
+ )
+ }
+
private def throwNumColumnsMismatch(
expectedNumColumns: Int,
childColumnTypes: Seq[DataType],
columnIndex: Int,
- unresolvedUnion: Union): Unit = {
+ unresolvedOperator: LogicalPlan): Unit = {
throw QueryCompilationErrors.numColumnsMismatch(
- "UNION",
+ unresolvedOperator.nodeName.toUpperCase(),
expectedNumColumns,
columnIndex + 1,
childColumnTypes.length,
- unresolvedUnion.origin
+ unresolvedOperator.origin
)
}
private def throwIncompatibleColumnTypeError(
- unresolvedUnion: Union,
+ unresolvedOperator: LogicalPlan,
columnIndex: Int,
childIndex: Int,
widenedColumnType: DataType,
columnTypeForCurrentRow: DataType): Nothing = {
throw QueryCompilationErrors.incompatibleColumnTypeError(
- "UNION",
+ unresolvedOperator.nodeName.toUpperCase(),
columnIndex,
childIndex + 1,
widenedColumnType,
columnTypeForCurrentRow,
hint = "",
- origin = unresolvedUnion.origin
+ origin = unresolvedOperator.origin
)
}
+
+ private def hasMapType(dt: DataType): Boolean = {
+ dt.existsRecursively(_.isInstanceOf[MapType])
+ }
+
+ private def hasVariantType(dt: DataType): Boolean = {
+ dt.existsRecursively(_.isInstanceOf[VariantType])
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SortResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SortResolver.scala
new file mode 100644
index 0000000000000..53c7b3f2366c1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SortResolver.scala
@@ -0,0 +1,397 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import java.util.{HashMap, IdentityHashMap, LinkedHashMap}
+
+import scala.collection.mutable
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.analysis.{
+ NondeterministicExpressionCollection,
+ UnresolvedAttribute
+}
+import org.apache.spark.sql.catalyst.expressions.{
+ Alias,
+ Attribute,
+ Expression,
+ ExprId,
+ NamedExpression,
+ SortOrder
+}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project, Sort}
+import org.apache.spark.sql.errors.QueryCompilationErrors
+
+/**
+ * Resolves a [[Sort]] by resolving its child and order expressions.
+ */
+class SortResolver(operatorResolver: Resolver, expressionResolver: ExpressionResolver)
+ extends TreeNodeResolver[Sort, LogicalPlan]
+ with ResolvesNameByHiddenOutput {
+ override protected val scopes: NameScopeStack = operatorResolver.getNameScopes
+ private val autoGeneratedAliasProvider = new AutoGeneratedAliasProvider(
+ expressionResolver.getExpressionIdAssigner
+ )
+
+ /**
+ * Resolve [[Sort]] operator.
+ *
+ * 1. Resolve [[Sort.child]].
+ * 2. Resolve order expressions using [[ExpressionResolver.resolveExpressionTreeInOperator]].
+ * 3. In case the order expressions contain only one element, `UnresolvedAttribute(ALL)`, which
+ * can't be resolved from the current scope (nor from the hidden output), skip the previous step
+ * and resolve it as an `ALL` keyword (this is done using [[ResolveAsAllKeyword]], by copying the
+ * child's output and transforming it into attributes).
+ * 4. In case attributes referenced in the order expressions are resolved using the hidden
+ * output (or in case we order by [[AggregateExpression]]s which are not present in
+ * [[Aggregate.aggregateExpressions]]), update the output of the child operator and add a
+ * [[Project]] as a parent of [[Sort]] with the original [[Project]]'s output (this is done by
+ * [[ResolvesNameByHiddenOutput]]). Query:
+ * {{{
+ * SELECT col1 FROM VALUES(1, 2) WHERE col2 > 2 ORDER BY col2;
+ * }}}
+ * Has the following unresolved plan:
+ *
+ * 'Sort ['col2 ASC NULLS FIRST], true
+ * +- 'Project ['col1]
+ * +- 'Filter ('col2 > 2)
+ * +- LocalRelation [col1#92225, col2#92226]
+ *
+ * Because `col2` from the [[Sort]] node is resolved using the hidden output, add it to the
+ * [[Project.projectList]] and add a [[Project]] with original project list as a top node:
+ *
+ * Project [col1]
+ * +- Sort [col2 ASC NULLS FIRST], true
+ * +- Project [col1, col2]
+ * +- Filter (col2 > 2)
+ * +- LocalRelation [col1, col2]
+ *
+ * Another example with ordering by [[AggregateExpression]]:
+ * {{{
+ * SELECT col1 FROM VALUES (1, 2) GROUP BY col1, col2 + 1 ORDER BY SUM(col1), col2 + 1;
+ * }}}
+ * Has the following unresolved plan:
+ *
+ * 'Sort ['SUM('col1) ASC NULLS FIRST, ('col2 + 1) ASC NULLS FIRST], true
+ * +- 'Aggregate ['col1, ('col2 + 1)], ['col1]
+ * +- LocalRelation [col1, col2]
+ *
+ * Because neither `SUM(col1)` nor `col2 + 1` from the [[Sort]] node are present in the
+ * [[Aggregate.aggregateExpressions]], add them to it and add a [[Project]] with the original
+ * project list as a top node (`SUM(col1)` is extracted in the
+ * [[AggregateExpressionResolver]] whereas `col2 + 1` is extracted using the
+ * `extractReferencedGroupingAndAggregateExpressions` helper method):
+ *
+ * Project [col1]
+ * +- Sort [sum(col1)#... ASC NULLS FIRST, (col2 + 1)#... ASC NULLS FIRST], true
+ * +- Aggregate [col1, (col2 + 1)],
+ * [col1, sum(col1) AS sum(col1)#..., (col2 + 1) AS (col2 + 1)#...]
+ * +- LocalRelation [col1, col2]
+ * 5. In case there are non-deterministic expressions in the order expressions, substitute them
+ * with derived attribute references to an artificial [[Project]] list.
+ */
+ override def resolve(unresolvedSort: Sort): LogicalPlan = {
+ val resolvedChild = operatorResolver.resolve(unresolvedSort.child)
+ if (canOrderByAll(unresolvedSort.order)) {
+ val sortOrder = unresolvedSort.order.head
+ val resolvedOrder =
+ scopes.current.output.map(a => sortOrder.copy(child = a.toAttribute))
+ unresolvedSort.copy(child = resolvedChild, order = resolvedOrder)
+ } else {
+ val unresolvedSortWithResolvedChild = unresolvedSort.copy(child = resolvedChild)
+
+ val (resolvedOrderExpressions, missingAttributes, aggregateExpressionsAliased) =
+ resolveOrderExpressions(
+ unresolvedSortWithResolvedChild,
+ scopes.current.output.toArray
+ )
+
+ val (finalOrderExpressions, missingExpressions) = resolvedChild match {
+ case aggregate: Aggregate =>
+ val (cleanedOrderExpressions, referencedGroupingExpressions) =
+ extractReferencedGroupingAndAggregateExpressions(aggregate, resolvedOrderExpressions)
+ (cleanedOrderExpressions, aggregateExpressionsAliased ++ referencedGroupingExpressions)
+ case other =>
+ (resolvedOrderExpressions, missingAttributes)
+ }
+
+ val resolvedChildWithMissingAttributes =
+ insertMissingExpressions(resolvedChild, missingExpressions)
+
+ val resolvedSort = unresolvedSort.copy(
+ child = resolvedChildWithMissingAttributes,
+ order = finalOrderExpressions
+ )
+
+ val sortWithOriginalOutput = retainOriginalOutput(resolvedSort, missingExpressions)
+
+ sortWithOriginalOutput match {
+ case project: Project =>
+ sortWithOriginalOutput
+ case sort: Sort =>
+ tryPullOutNondeterministic(sort)
+ }
+ }
+ }
+
+ /**
+ * Resolve order expressions of an unresolved [[Sort]], returning attributes resolved using the
+ * hidden output and extracted [[AggregateExpression]]s. [[UnresolvedAttribute]] resolution
+ * respects the following order.
+ *
+ * 1. Attribute can be resolved using the current scope:
+ * {{{
+ * -- This one will resolve `col1` from the current scope (its value will be 1)
+ * SELECT col1 FROM VALUES(1, 2) WHERE (SELECT col1 FROM VALUES(3)) ORDER BY col1;
+ * }}}
+ *
+ * 2. Attribute can be resolved using the hidden output. Attribute is added to
+ * `missingAttributes` which is used to update the plan using the
+ * [[ResolvesNameByHiddenOutput]].
+ * {{{
+ * -- This one will resolve `col2` from hidden output
+ * SELECT col1 FROM VALUES(1, 2) WHERE col2 > 2 ORDER BY col2;
+ * }}}
+ *
+ * 3. In case attribute can't be resolved from output nor from hidden output, throw
+ * `UNRESOLVED_COLUMN` exception:
+ * {{{
+ * -- Following queries throw `UNRESOLVED_COLUMN` exception:
+ * SELECT col1 FROM VALUES(1,2) GROUP BY col1 HAVING col1 > 1 ORDER BY col2;
+ * SELECT col1 FROM VALUES(1) ORDER BY col2;
+ * }}}
+ */
+ private def resolveOrderExpressions(
+ unresolvedSort: Sort,
+ projectListArray: Array[NamedExpression]): (Seq[SortOrder], Seq[Attribute], Seq[Alias]) = {
+ val orderByOrdinal = conf.orderByOrdinal
+ val referencedAttributes = new HashMap[ExprId, Attribute]
+ val aggregateExpressionsAliased = new mutable.ArrayBuffer[Alias]
+
+ val resolvedSortOrder = unresolvedSort.order.map { sortOrder =>
+ val partiallyResolvedSortOrderChild =
+ expressionResolver.resolveExpressionTreeInOperator(sortOrder.child, unresolvedSort)
+
+ referencedAttributes.putAll(expressionResolver.getLastReferencedAttributes)
+ aggregateExpressionsAliased ++= expressionResolver.getLastExtractedAggregateExpressionAliases
+
+ val resolvedSortOrderChild = tryReplaceOrdinalsInSortOrderChild(
+ partiallyResolvedSortOrderChild,
+ orderByOrdinal,
+ projectListArray,
+ unresolvedSort
+ )
+
+ sortOrder.copy(child = resolvedSortOrderChild)
+ }
+
+ val missingAttributes = scopes.current.resolveMissingAttributesByHiddenOutput(
+ referencedAttributes
+ )
+
+ (resolvedSortOrder, missingAttributes, aggregateExpressionsAliased.toSeq)
+ }
+
+ /**
+ * Replaces ordinals with the actual expressions from the resolved project list, or throws if an
+ * ordinal value is greater than the project list size. This is done only when
+ * `conf.orderByOrdinal` is `true`.
+ *
+ * Example 1:
+ * {{{
+ * -- This one would order by `col1` and `col2 + 1`
+ * SELECT col1, col2 + 1 FROM VALUES(1, 2) ORDER BY col1, 2;
+ * }}}
+ *
+ * Example 2:
+ * {{{
+ * -- This one would throw `ORDER_BY_POS_OUT_OF_RANGE` error
+ * SELECT col1 FROM VALUES(1, 2) ORDER BY 2;
+ * }}}
+ */
+ private def tryReplaceOrdinalsInSortOrderChild(
+ partiallyResolvedSortOrderChild: Expression,
+ orderByOrdinal: Boolean,
+ projectListArray: Array[NamedExpression],
+ unresolvedSort: Sort): Expression = {
+ if (orderByOrdinal) {
+ TryExtractOrdinal(partiallyResolvedSortOrderChild) match {
+ case Some(ordinal) =>
+ if (ordinal > projectListArray.length) {
+ throw QueryCompilationErrors.orderByPositionRangeError(
+ ordinal,
+ projectListArray.length,
+ unresolvedSort
+ )
+ }
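+ // SQL ordinals are 1-based, so convert to the 0-based project list index.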
+ projectListArray(ordinal - 1).toAttribute
+ case None => partiallyResolvedSortOrderChild
+ }
+ } else {
+ partiallyResolvedSortOrderChild
+ }
+ }
+
+ /**
+ * Extracts the referenced grouping and aggregate expressions from the order expressions. This is
+ * used to update the output of the child operator and add a [[Project]] as a parent of [[Sort]]
+ * later during the resolution (if needed). Consider the following example:
+ * {{{
+ * SELECT col1 FROM VALUES (1, 2) GROUP BY col1, col2 ORDER BY col2;
+ * }}}
+ *
+ * The unresolved plan would look like this:
+ *
+ * 'Sort ['col2 ASC NULLS FIRST], true
+ * +- 'Aggregate ['col1, 'col2], ['col1]
+ * +- LocalRelation [col1, col2]
+ *
+ * As can be seen, `col2` (an ordering expression) is not present in the [[Aggregate]] operator,
+ * and thus we return it from this method. The plan will be altered later during the resolution
+ * using the [[ResolvesNameByHiddenOutput]] (`col2` will be added to
+ * [[Aggregate.aggregateExpressions]], [[Project]] will be added as a top node with original
+ * [[Aggregate]] output) and it will look like:
+ *
+ * Project [col1]
+ * +- Sort [col2 ASC NULLS FIRST], true
+ * +- Aggregate [col1, col2], [col1, col2]
+ * +- LocalRelation [col1, col2]
+ */
+ private def extractReferencedGroupingAndAggregateExpressions(
+ aggregate: Aggregate,
+ sortOrderEntries: Seq[SortOrder]): (Seq[SortOrder], Seq[NamedExpression]) = {
+ val aliasChildToAliasInAggregateExpressions = new IdentityHashMap[Expression, Alias]
+ val aggregateExpressionsSemanticComparator = new SemanticComparator(
+ aggregate.aggregateExpressions.map {
+ case alias: Alias =>
+ aliasChildToAliasInAggregateExpressions.put(alias.child, alias)
+ alias.child
+ case other => other
+ }
+ )
+
+ val groupingExpressionsSemanticComparator = new SemanticComparator(
+ aggregate.groupingExpressions
+ )
+
+ val referencedGroupingExpressions = new mutable.ArrayBuffer[NamedExpression]
+ val transformedSortOrderEntries = sortOrderEntries.map { sortOrder =>
+ sortOrder.copy(child = sortOrder.child.transform {
+ case expression: Expression =>
+ extractReferencedGroupingAndAggregateExpressionsFromOrderExpression(
+ expression = expression,
+ aggregateExpressionsSemanticComparator = aggregateExpressionsSemanticComparator,
+ groupingExpressionsSemanticComparator = groupingExpressionsSemanticComparator,
+ aliasChildToAliasInAggregateExpressions = aliasChildToAliasInAggregateExpressions,
+ referencedGroupingExpressions = referencedGroupingExpressions
+ )
+ })
+ }
+ (transformedSortOrderEntries, referencedGroupingExpressions.toSeq)
+ }
+
+ private def extractReferencedGroupingAndAggregateExpressionsFromOrderExpression(
+ expression: Expression,
+ aggregateExpressionsSemanticComparator: SemanticComparator,
+ groupingExpressionsSemanticComparator: SemanticComparator,
+ aliasChildToAliasInAggregateExpressions: IdentityHashMap[Expression, Alias],
+ referencedGroupingExpressions: mutable.ArrayBuffer[NamedExpression]): Expression = {
+ aggregateExpressionsSemanticComparator.collectFirst(expression) match {
+ case Some(attribute: Attribute)
+ if !aliasChildToAliasInAggregateExpressions.containsKey(attribute) =>
+ attribute
+ case Some(expression) =>
+ aliasChildToAliasInAggregateExpressions.get(expression) match {
+ case null =>
+ throw SparkException.internalError(
+ s"No parent alias for expression $expression while extracting aggregate" +
+ s"expressions in Sort operator."
+ )
+ case alias: Alias =>
+ alias.toAttribute
+ }
+ case None if groupingExpressionsSemanticComparator.exists(expression) =>
+ expression match {
+ case attribute: Attribute =>
+ referencedGroupingExpressions += attribute
+ attribute
+ case other =>
+ val alias = autoGeneratedAliasProvider.newAlias(child = other)
+ referencedGroupingExpressions += alias
+ alias.toAttribute
+ }
+ case None => expression
+ }
+ }
+
+ /**
+ * In case there are non-deterministic expressions among the `order` expressions, replace them
+ * with attributes created out of the corresponding non-deterministic expressions. Example:
+ *
+ * {{{ SELECT 1 ORDER BY RAND(); }}}
+ *
+ * This query would have the following analyzed plan:
+ *
+ * Project [1]
+ * +- Sort [_nondeterministic ASC NULLS FIRST], true
+ * +- Project [1, rand(...) AS _nondeterministic#...]
+ * +- Project [1 AS 1#...]
+ * +- OneRowRelation
+ */
+ private def tryPullOutNondeterministic(sort: Sort): LogicalPlan = {
+ val nondeterministicToAttributes: LinkedHashMap[Expression, NamedExpression] =
+ NondeterministicExpressionCollection.getNondeterministicToAttributes(
+ sort.order.map(_.child)
+ )
+
+ if (!nondeterministicToAttributes.isEmpty) {
+ val newChild = Project(
+ scopes.current.output ++ nondeterministicToAttributes.values.asScala.toSeq,
+ sort.child
+ )
+ val resolvedOrder = sort.order.map { sortOrder =>
+ sortOrder.copy(
+ child = PullOutNondeterministicExpressionInExpressionTree(
+ sortOrder.child,
+ nondeterministicToAttributes
+ )
+ )
+ }
+ val resolvedSort = sort.copy(
+ order = resolvedOrder,
+ child = newChild
+ )
+ Project(projectList = scopes.current.output, child = resolvedSort)
+ } else {
+ sort
+ }
+ }
+
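+ // `ORDER BY ALL` is treated as the keyword only when no real column named `ALL` can be
+ // resolved. A sketch of the intended behavior: `SELECT col1 AS all FROM VALUES (1) ORDER BY all`
+ // sorts by that column, while `SELECT col1, col2 FROM VALUES (1, 2) ORDER BY ALL` sorts by both
+ // output columns.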
+ private def canOrderByAll(expressions: Seq[SortOrder]): Boolean = {
+ val isOrderByAll = expressions match {
+ case Seq(SortOrder(unresolvedAttribute: UnresolvedAttribute, _, _, _)) =>
+ unresolvedAttribute.equalsIgnoreCase("ALL")
+ case _ => false
+ }
+ isOrderByAll && scopes.current
+ .resolveMultipartName(Seq("ALL"), canResolveNameByHiddenOutput = true)
+ .candidates
+ .isEmpty
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SubqueryExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SubqueryExpressionResolver.scala
new file mode 100644
index 0000000000000..c36024e7269e0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SubqueryExpressionResolver.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedInlineTable, ValidateSubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{
+ AttributeReference,
+ Exists,
+ Expression,
+ InSubquery,
+ ListQuery,
+ ScalarSubquery,
+ SubExprUtils,
+ SubqueryExpression
+}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.errors.QueryCompilationErrors
+
+/**
+ * [[SubqueryExpressionResolver]] resolves specific [[SubqueryExpression]]s, such as
+ * [[ScalarSubquery]], [[ListQuery]], and [[Exists]].
+ */
+class SubqueryExpressionResolver(expressionResolver: ExpressionResolver, resolver: Resolver) {
+ private val scopes = resolver.getNameScopes
+ private val traversals = expressionResolver.getExpressionTreeTraversals
+ private val cteRegistry = resolver.getCteRegistry
+ private val subqueryRegistry = resolver.getSubqueryRegistry
+ private val expressionIdAssigner = expressionResolver.getExpressionIdAssigner
+ private val typeCoercionResolver = expressionResolver.getGenericTypeCoercionResolver
+
+ /**
+ * Resolve [[ScalarSubquery]]:
+ * - Resolve the subquery plan;
+ * - Get outer references;
+ * - Type coerce it;
+ * - Validate it.
+ */
+ def resolveScalarSubquery(unresolvedScalarSubquery: ScalarSubquery): Expression = {
+ traversals.current.parentOperator match {
+ case unresolvedInlineTable: UnresolvedInlineTable =>
+ throw QueryCompilationErrors.inlineTableContainsScalarSubquery(unresolvedInlineTable)
+ case _ =>
+ }
+
+ val resolvedSubqueryExpressionPlan = resolveSubqueryExpressionPlan(
+ unresolvedScalarSubquery.plan
+ )
+
+ val resolvedScalarSubquery = unresolvedScalarSubquery.copy(
+ plan = resolvedSubqueryExpressionPlan.plan,
+ outerAttrs = resolvedSubqueryExpressionPlan.outerExpressions
+ )
+
+ val coercedScalarSubquery =
+ typeCoercionResolver.resolve(resolvedScalarSubquery).asInstanceOf[ScalarSubquery]
+
+ validateSubqueryExpression(coercedScalarSubquery)
+
+ coercedScalarSubquery
+ }
+
+ /**
+ * Resolve [[InSubquery]]:
+ * - Resolve the underlying [[ListQuery]];
+ * - Resolve the values;
+ * - Type coerce it;
+ * - Validate it.
+ */
+ def resolveInSubquery(unresolvedInSubquery: InSubquery): Expression = {
+ val resolvedQuery =
+ expressionResolver.resolve(unresolvedInSubquery.query).asInstanceOf[ListQuery]
+
+ val resolvedValues = unresolvedInSubquery.values.map { value =>
+ expressionResolver.resolve(value)
+ }
+
+ val resolvedInSubquery =
+ unresolvedInSubquery.copy(values = resolvedValues, query = resolvedQuery)
+
+ val coercedInSubquery =
+ typeCoercionResolver.resolve(resolvedInSubquery).asInstanceOf[InSubquery]
+
+ validateSubqueryExpression(coercedInSubquery.query)
+
+ coercedInSubquery
+ }
+
+ /**
+ * Resolve [[ListQuery]], which is always a child of an [[InSubquery]].
+ */
+ def resolveListQuery(unresolvedListQuery: ListQuery): Expression = {
+ val resolvedSubqueryExpressionPlan = resolveSubqueryExpressionPlan(unresolvedListQuery.plan)
+
+ unresolvedListQuery.copy(
+ plan = resolvedSubqueryExpressionPlan.plan,
+ outerAttrs = resolvedSubqueryExpressionPlan.outerExpressions,
+ numCols = resolvedSubqueryExpressionPlan.output.size
+ )
+ }
+
+ /**
+ * Resolve [[Exists]] subquery:
+ * - Resolve the subquery plan;
+ * - Get outer references;
+ * - Type coerce it;
+ * - Validate it.
+ */
+ def resolveExists(unresolvedExists: Exists): Expression = {
+ val resolvedSubqueryExpressionPlan = resolveSubqueryExpressionPlan(unresolvedExists.plan)
+
+ val resolvedExists = unresolvedExists.copy(
+ plan = resolvedSubqueryExpressionPlan.plan,
+ outerAttrs = resolvedSubqueryExpressionPlan.outerExpressions
+ )
+
+ val coercedExists = typeCoercionResolver.resolve(resolvedExists).asInstanceOf[Exists]
+
+ validateSubqueryExpression(coercedExists)
+
+ coercedExists
+ }
+
+ /**
+ * Resolve [[SubqueryExpression]] plan. Subquery expressions require:
+ * - Fresh [[NameScope]] to isolate the name resolution;
+ * - Fresh [[ExpressionIdAssigner]] mapping, because it's a separate plan branch;
+ * - Fresh [[CteScope]] with the root flag set so that [[CTERelationDef]]s are merged
+ * under the subquery root.
+ * - Fresh [[SubqueryRegistry]] scope to isolate the subquery expression plan resolution.
+ */
+ private def resolveSubqueryExpressionPlan(
+ unresolvedSubqueryPlan: LogicalPlan): ResolvedSubqueryExpressionPlan = {
+ val resolvedSubqueryExpressionPlan = scopes.withNewScope(isSubqueryRoot = true) {
+ expressionIdAssigner.withNewMapping(isSubqueryRoot = true) {
+ cteRegistry.withNewScope(isRoot = true) {
+ subqueryRegistry.withNewScope() {
+ val resolvedPlan = resolver.resolve(unresolvedSubqueryPlan)
+
+ ResolvedSubqueryExpressionPlan(
+ plan = resolvedPlan,
+ output = scopes.current.output,
+ outerExpressions = getOuterExpressions(resolvedPlan)
+ )
+ }
+ }
+ }
+ }
+
+ for (expression <- resolvedSubqueryExpressionPlan.outerExpressions) {
+ expressionResolver.validateExpressionUnderSupportedOperator(expression)
+ }
+
+ resolvedSubqueryExpressionPlan
+ }
+
+ /**
+ * Get outer expressions from the subquery plan. These are the expressions that are actual
+ * [[AttributeReference]]s or [[AggregateExpression]]s with outer references in the subtree.
+ * [[AggregateExpressionResolver]] strips out the outer aggregate expression subtrees, but
+ * for [[SubqueryExpression.outerAttrs]] to be well-formed, we need to put those back. After
+ * that we validate the outer expressions.
+ *
+ * We reuse [[SubExprUtils.getOuterReferences]] for the top-down traversal to stay compatible
+ * with the fixed-point Analyzer, because the order of [[SubqueryExpression.outerAttrs]] is a
+ * part of an implicit alias name for that expression.
+ */
+ private def getOuterExpressions(subQuery: LogicalPlan): Seq[Expression] = {
+ val subqueryScope = subqueryRegistry.currentScope
+
+ SubExprUtils.getOuterReferences(subQuery).map {
+ case attribute: AttributeReference =>
+ subqueryScope.getOuterAggregateExpression(attribute.exprId).getOrElse(attribute)
+ case other =>
+ other
+ }
+ }
+
+ /**
+ * Generically validates [[SubqueryExpression]]. Should not be done in the main pass, because the
+ * logic is too sophisticated.
+ */
+ private def validateSubqueryExpression(subqueryExpression: SubqueryExpression): Unit = {
+ ValidateSubqueryExpression(traversals.current.parentOperator, subqueryExpression)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SubqueryScope.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SubqueryScope.scala
new file mode 100644
index 0000000000000..9b7517781843f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SubqueryScope.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import java.util.{ArrayDeque, ArrayList, HashMap}
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+
+/**
+ * The [[SubqueryScope]] is managed through the whole resolution process of a given
+ * [[SubqueryExpression]] plan.
+ *
+ * The reason why we need this scope is that [[AggregateExpression]]s with [[OuterReference]]s are
+ * handled in a special way. Consider this query:
+ *
+ * {{{
+ * -- t1.col2 is an outer reference
+ * SELECT col1 FROM VALUES (1, 2) t1 GROUP BY col1 HAVING (
+ * SELECT * FROM VALUES (1, 2) t2 WHERE t2.col2 == MAX(t1.col2)
+ * )
+ * }}}
+ *
+ * During the [[Exists]] resolution inside the HAVING clause we encounter "t1.col2" name, which is
+ * resolved to an [[OuterReference]]. There's an [[AggregateExpression]] on top of it. This whole
+ * expression is not local to the subquery, and thus it belongs to an outer [[Aggregate]]
+ * operator below the `HAVING` clause. We need to pull it up out of the subquery, and insert
+ * it in the [[Aggregate]] operator. So the resolution order is as follows:
+ * - Resolve "t1.col2" to an [[OuterReference]] in [[ExpressionResolver.resolveAttribute]];
+ * - Resolve the [[AggregateExpression]] in [[AggregateExpressionResolver.resolve]];
+ * - Detect an outer reference below the aggregate expression, cut the whole subtree with outer
+ * references stripped away, alias it and insert it in this [[SubqueryScope]];
+ * - Replace the aggregate expression with an [[OuterReference]] to the [[AttributeReference]] from
+ * that artificial [[Alias]];
+ * - When the resolution of the [[SubqueryExpression]] is finished,
+ * [[SubqueryRegistry.withNewScope]] merges the lower scope into the upper one, and all the
+ * outer aggregate expression references are appended to the common
+ * [[lowerAliasedOuterAggregateExpressions]] list.
+ * - Finally, the resolution of the `HAVING` clause can insert the missing aggregate expression
+ * into the lower [[Aggregate]] operator. During this process we must call
+ * [[ExpressionIdAssigner.mapExpression]] on the new alias, because this auto-generated alias
+ * is new to the query plan, so that [[ExpressionIdAssigner]] remembers it.
+ *
+ * Notes:
+ * - Spark only supports outer aggregates in the subqueries inside `HAVING`;
+ * - The subtree under a given [[AggregateExpression]] can be arbitrary, but must contain either
+ * local or outer references; mixing the two is disallowed.
+ * - We can have several subquery expressions in the HAVING clause, which is why we append outer
+ * aggregate expressions from lower scopes in [[mergeChildScope]].
+ *
+ * @param isOuterAggregateAllowed Whether outer aggregate expressions are allowed in this scope.
+ * Currently Spark only supports those in HAVING.
+ */
+class SubqueryScope(val isOuterAggregateAllowed: Boolean = false) {
+ private val outerAliasedAggregateExpressions = new ArrayList[Alias]
+ private val outerAliasedAggregateExpressionById = new HashMap[ExprId, AggregateExpression]
+ private val lowerAliasedOuterAggregateExpressions = new ArrayList[Alias]
+
+ /**
+ * Add an outer `aggregateExpression` to this scope. The `alias` is auto-generated and will be
+ * later inserted into the outer [[Aggregate]] operator.
+ */
+ def addOuterAggregateExpression(alias: Alias, aggregateExpression: AggregateExpression): Unit = {
+ outerAliasedAggregateExpressions.add(alias)
+ outerAliasedAggregateExpressionById.put(alias.exprId, aggregateExpression)
+ }
+
+ /**
+ * Get the outer `aggregateExpression` by its `aliasId`. This is used in the
+ * [[SubqueryExpressionResolver]] to replace the collected outer references with the original
+ * [[AggregateExpression]] subtrees, which is required for a well-formed
+ * [[SubqueryExpression.outerAttrs]].
+ */
+ def getOuterAggregateExpression(aliasId: ExprId): Option[AggregateExpression] = {
+ Option(outerAliasedAggregateExpressionById.get(aliasId))
+ }
+
+ /**
+ * Get the outer aggregate expression aliased from the lower subquery scope.
+ */
+ def getLowerOuterAggregateExpressionAliases: Seq[Alias] = {
+ lowerAliasedOuterAggregateExpressions.asScala.toSeq
+ }
+
+ /**
+ * Merge `childScope` by extending our `lowerAliasedOuterAggregateExpressions` with
+ * `childScope.outerAliasedAggregateExpressions`.
+ */
+ def mergeChildScope(childScope: SubqueryScope): Unit = {
+ lowerAliasedOuterAggregateExpressions.addAll(childScope.outerAliasedAggregateExpressions)
+ }
+}
+
+/**
+ * The [[SubqueryRegistry]] manages the stack of [[SubqueryScope]]s during the resolution of
+ * the whole SQL query. Every new [[SubqueryScope]] has its own isolated scope.
+ */
+class SubqueryRegistry {
+ private val stack = new ArrayDeque[SubqueryScope]
+ stack.push(new SubqueryScope)
+
+ /**
+ * Get the current [[SubqueryScope]].
+ */
+ def currentScope: SubqueryScope = stack.peek()
+
+ /**
+ * A RAII-wrapper for pushing/popping scopes. This is used by the [[SubqueryExpressionResolver]]
+ * to create a new scope for each [[SubqueryExpression]].
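+ *
+ * A minimal usage sketch (with a hypothetical caller; `resolver` and `subqueryPlan` are
+ * placeholders): aliases collected in the child scope become visible to the parent scope via
+ * [[SubqueryScope.mergeChildScope]] once the block exits:
+ * {{{
+ *   subqueryRegistry.withNewScope(isOuterAggregateAllowed = true) {
+ *     resolver.resolve(subqueryPlan)
+ *   }
+ *   val aliases = subqueryRegistry.currentScope.getLowerOuterAggregateExpressionAliases
+ * }}}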
+ */
+ def withNewScope[R](isOuterAggregateAllowed: Boolean = false)(body: => R): R = {
+ stack.push(new SubqueryScope(isOuterAggregateAllowed = isOuterAggregateAllowed))
+ try {
+ body
+ } finally {
+ val childScope = stack.pop()
+ currentScope.mergeChildScope(childScope)
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimeAddResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimeAddResolver.scala
index 74b101273aaac..70d69daee5cb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimeAddResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimeAddResolver.scala
@@ -31,8 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, TimeAdd}
class TimeAddResolver(
expressionResolver: ExpressionResolver,
timezoneAwareExpressionResolver: TimezoneAwareExpressionResolver)
- extends TreeNodeResolver[TimeAdd, Expression]
- with ResolvesExpressionChildren {
+ extends TreeNodeResolver[TimeAdd, Expression]
+ with ResolvesExpressionChildren {
private val typeCoercionTransformations: Seq[Expression => Expression] =
if (conf.ansiEnabled) {
@@ -44,8 +44,8 @@ class TimeAddResolver(
new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionTransformations)
override def resolve(unresolvedTimeAdd: TimeAdd): Expression = {
- val timeAddWithResolvedChildren: TimeAdd =
- withResolvedChildren(unresolvedTimeAdd, expressionResolver.resolve)
+ val timeAddWithResolvedChildren =
+ withResolvedChildren(unresolvedTimeAdd, expressionResolver.resolve _)
val timeAddWithTypeCoercion: Expression = typeCoercionResolver
.resolve(timeAddWithResolvedChildren)
timezoneAwareExpressionResolver.withResolvedTimezone(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala
index 5ba08c0c3edb3..539202c06a8a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.analysis.resolver
-import org.apache.spark.sql.catalyst.expressions.{Expression, TimeZoneAwareExpression}
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, TimeZoneAwareExpression}
/**
* Resolves [[TimeZoneAwareExpressions]] by applying the session's local timezone.
@@ -32,17 +32,24 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, TimeZoneAwareExpre
class TimezoneAwareExpressionResolver(expressionResolver: TreeNodeResolver[Expression, Expression])
extends TreeNodeResolver[TimeZoneAwareExpression, Expression]
with ResolvesExpressionChildren {
+ val typeCoercionResolver = new TypeCoercionResolver(this)
/**
* Resolves a [[TimeZoneAwareExpression]] by resolving its children and applying a timezone.
+ * If the resolved expression is a [[Cast]], type coercion is not needed. Otherwise we apply
+ * [[TypeCoercionResolver]] to the resolved expression.
*
* @param unresolvedTimezoneExpression The [[TimeZoneAwareExpression]] to resolve.
- * @return A resolved [[Expression]] with the session's local timezone applied.
+ * @return A resolved [[Expression]] with the session's local timezone applied, and optionally
+ * type coerced.
*/
override def resolve(unresolvedTimezoneExpression: TimeZoneAwareExpression): Expression = {
val expressionWithResolvedChildren =
- withResolvedChildren(unresolvedTimezoneExpression, expressionResolver.resolve)
- withResolvedTimezoneCopyTags(expressionWithResolvedChildren, conf.sessionLocalTimeZone)
+ withResolvedChildren(unresolvedTimezoneExpression, expressionResolver.resolve _)
+ withResolvedTimezone(expressionWithResolvedChildren, conf.sessionLocalTimeZone) match {
+ case cast: Cast => cast
+ case other => typeCoercionResolver.resolve(other)
+ }
}
/**
@@ -55,7 +62,7 @@ class TimezoneAwareExpressionResolver(expressionResolver: TreeNodeResolver[Expre
* @param timeZoneId The timezone ID to apply.
* @return A new [[TimeZoneAwareExpression]] with the specified timezone and original tags.
*/
- def withResolvedTimezoneCopyTags(expression: Expression, timeZoneId: String): Expression =
+ def withResolvedTimezone(expression: Expression, timeZoneId: String): Expression =
expression match {
case timezoneExpression: TimeZoneAwareExpression if timezoneExpression.timeZoneId.isEmpty =>
val withTimezone = timezoneExpression.withTimeZone(timeZoneId)
@@ -63,14 +70,4 @@ class TimezoneAwareExpressionResolver(expressionResolver: TreeNodeResolver[Expre
withTimezone
case other => other
}
-
- /**
- * Apply timezone to [[TimeZoneAwareExpression]] expressions.
- */
- def withResolvedTimezone(expression: Expression, timeZoneId: String): Expression =
- expression match {
- case timezoneExpression: TimeZoneAwareExpression if timezoneExpression.timeZoneId.isEmpty =>
- timezoneExpression.withTimeZone(timeZoneId)
- case other => other
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TryExtractOrdinal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TryExtractOrdinal.scala
new file mode 100644
index 0000000000000..42766a78e248f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TryExtractOrdinal.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral}
+
+/**
+ * Try to extract an ordinal from an expression. Returns `Some(ordinal)` if the expression is an
+ * [[IntegerLiteral]], `None` otherwise.
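+ *
+ * A minimal usage sketch (assuming the standard [[Literal]] and [[UnresolvedAttribute]]
+ * constructors):
+ * {{{
+ *   TryExtractOrdinal(Literal(2))                  // Some(2)
+ *   TryExtractOrdinal(UnresolvedAttribute("col1")) // None
+ * }}}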
+ */
+object TryExtractOrdinal {
+ def apply(expression: Expression): Option[Int] = {
+ expression match {
+ case IntegerLiteral(literal) =>
+ Some(literal)
+ case other => None
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala
index 04089512b31b2..7f425c5708591 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala
@@ -39,8 +39,8 @@ class UnaryMinusResolver(
new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionTransformations)
override def resolve(unresolvedUnaryMinus: UnaryMinus): Expression = {
- val unaryMinusWithResolvedChildren: UnaryMinus =
- withResolvedChildren(unresolvedUnaryMinus, expressionResolver.resolve)
+ val unaryMinusWithResolvedChildren =
+ withResolvedChildren(unresolvedUnaryMinus, expressionResolver.resolve _)
typeCoercionResolver.resolve(unaryMinusWithResolvedChildren)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnresolvedCteRelationRef.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnresolvedCteRelationRef.scala
new file mode 100644
index 0000000000000..d0011557514df
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnresolvedCteRelationRef.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.analysis.{NamedRelation, UnresolvedLeafNode}
+
+/**
+ * A reference to a CTE definition in the form of an unresolved relation. This node is introduced by
+ * [[IdentifierAndCteSubstitutor]] to replace the CTE reference, so that [[MetadataResolver]]
+ * avoids unnecessary catalog RPC lookups.
+ */
+case class UnresolvedCteRelationRef(override val name: String)
+ extends UnresolvedLeafNode
+ with NamedRelation
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnsupportedExpressionInOperatorValidation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnsupportedExpressionInOperatorValidation.scala
new file mode 100644
index 0000000000000..af093315e5860
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnsupportedExpressionInOperatorValidation.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.sql.catalyst.expressions.{
+ Expression,
+ Generator,
+ WindowExpression
+}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{
+ Aggregate,
+ BaseEvalPythonUDTF,
+ CollectMetrics,
+ Generate,
+ LateralJoin,
+ LogicalPlan,
+ Project,
+ Window
+}
+import org.apache.spark.sql.errors.QueryCompilationErrors
+
+object UnsupportedExpressionInOperatorValidation {
+
+ /**
+ * Check whether `expression` is not allowed to exist in `operator`'s expression tree.
+ */
+ def isExpressionInUnsupportedOperator(expression: Expression, operator: LogicalPlan): Boolean = {
+ expression match {
+ case _: WindowExpression => operator.isInstanceOf[Window]
+ case _: AggregateExpression =>
+ !(operator.isInstanceOf[Project] ||
+ operator.isInstanceOf[Aggregate] ||
+ operator.isInstanceOf[Window] ||
+ operator.isInstanceOf[CollectMetrics] ||
+ onlyInLateralSubquery(operator))
+ case _: Generator =>
+ !(operator.isInstanceOf[Generate] ||
+ operator.isInstanceOf[BaseEvalPythonUDTF])
+ case _ =>
+ false
+ }
+ }
+
+ private def onlyInLateralSubquery(operator: LogicalPlan): Boolean = {
+ operator.isInstanceOf[LateralJoin] && {
+ // TODO: check if we are resolving a lateral join condition once lateral join is supported.
+ throw QueryCompilationErrors.unsupportedSinglePassAnalyzerFeature(
+ s"${operator.getClass} operator resolution"
+ )
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala
index e7e9c5ec822ae..f96ce534556e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver
import java.util.ArrayDeque
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, UnresolvedRelation}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, View}
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -35,6 +35,13 @@ class ViewResolver(resolver: Resolver, catalogManager: CatalogManager)
private val sourceUnresolvedRelationStack = new ArrayDeque[UnresolvedRelation]
private val viewResolutionContextStack = new ArrayDeque[ViewResolutionContext]
+ def getCatalogAndNamespace: Option[Seq[String]] =
+ if (viewResolutionContextStack.isEmpty) {
+ None
+ } else {
+ viewResolutionContextStack.peek().catalogAndNamespace
+ }
+
/**
* This method preserves the resolved [[UnresolvedRelation]] for the further view resolution
* process.
@@ -97,22 +104,28 @@ class ViewResolver(resolver: Resolver, catalogManager: CatalogManager)
*/
private def withViewResolutionContext(unresolvedView: View)(
body: => LogicalPlan): (LogicalPlan, ViewResolutionContext) = {
- val viewResolutionContext = if (viewResolutionContextStack.isEmpty()) {
- ViewResolutionContext(
- nestedViewDepth = 1,
- maxNestedViewDepth = conf.maxNestedViewDepth
+ AnalysisContext.withAnalysisContext(unresolvedView.desc) {
+ val prevContext = if (viewResolutionContextStack.isEmpty()) {
+ ViewResolutionContext(
+ nestedViewDepth = 0,
+ maxNestedViewDepth = conf.maxNestedViewDepth
+ )
+ } else {
+ viewResolutionContextStack.peek()
+ }
+
+ val viewResolutionContext = prevContext.copy(
+ nestedViewDepth = prevContext.nestedViewDepth + 1,
+ catalogAndNamespace = Some(unresolvedView.desc.viewCatalogAndNamespace)
)
- } else {
- val prevContext = viewResolutionContextStack.peek()
- prevContext.copy(nestedViewDepth = prevContext.nestedViewDepth + 1)
- }
- viewResolutionContext.validate(unresolvedView)
+ viewResolutionContext.validate(unresolvedView)
- viewResolutionContextStack.push(viewResolutionContext)
- try {
- (body, viewResolutionContext)
- } finally {
- viewResolutionContextStack.pop()
+ viewResolutionContextStack.push(viewResolutionContext)
+ try {
+ (body, viewResolutionContext)
+ } finally {
+ viewResolutionContextStack.pop()
+ }
}
}
@@ -131,8 +144,12 @@ class ViewResolver(resolver: Resolver, catalogManager: CatalogManager)
* @param nestedViewDepth Current nested view depth. Cannot exceed the `maxNestedViewDepth`.
* @param maxNestedViewDepth Maximum allowed nested view depth. Configured in the upper context
* based on [[SQLConf.MAX_NESTED_VIEW_DEPTH]].
+ * @param catalogAndNamespace Catalog and namespace under which the [[View]] was created.
*/
-case class ViewResolutionContext(nestedViewDepth: Int, maxNestedViewDepth: Int) {
+case class ViewResolutionContext(
+ nestedViewDepth: Int,
+ maxNestedViewDepth: Int,
+ catalogAndNamespace: Option[Seq[String]] = None) {
def validate(unresolvedView: View): Unit = {
if (nestedViewDepth > maxNestedViewDepth) {
throw QueryCompilationErrors.viewDepthExceedsMaxResolutionDepthError(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index fabe551d054ca..93462e5773911 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -776,6 +776,8 @@ case class UnresolvedStarWithColumns(
replacedAndExistingColumns ++ newColumns
}
+
+ override def toString: String = super[Expression].toString
}
/**
@@ -812,6 +814,8 @@ case class UnresolvedStarWithColumnsRenames(
)
}
}
+
+ override def toString: String = super[LeafExpression].toString
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
index 923373c1856a9..f2fd3b90f6468 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
@@ -207,7 +207,7 @@ object SQLFunction {
val returnType = parseReturnTypeText(props(RETURN_TYPE), isTableFunc, parser)
SQLFunction(
name = function.identifier,
- inputParam = props.get(INPUT_PARAM).map(parseTableSchema(_, parser)),
+ inputParam = props.get(INPUT_PARAM).map(parseRoutineParam(_, parser)),
returnType = returnType.get,
exprText = props.get(EXPRESSION),
queryText = props.get(QUERY),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index b123952c5f086..3eb1b35d24195 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -38,9 +38,10 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
import org.apache.spark.sql.catalyst.catalog.SQLFunction.parseDefault
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, ExpressionInfo, NamedArgumentExpression, NamedExpression, ScalarSubquery, UpCast}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, Expression, ExpressionInfo, LateralSubquery, NamedArgumentExpression, NamedExpression, OuterReference, ScalarSubquery, UpCast}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter, LocalRelation, LogicalPlan, NamedParametersSupport, Project, SubqueryAlias, View}
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter, LateralJoin, LocalRelation, LogicalPlan, NamedParametersSupport, OneRowRelation, Project, SubqueryAlias, View}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils}
import org.apache.spark.sql.connector.catalog.CatalogManager
@@ -1675,6 +1676,86 @@ class SessionCatalog(
}
}
+ /**
+ * Constructs a SQL table function plan.
+ * This function should be invoked with the captured SQL configs from the function.
+ *
+ * Example SQL table function:
+ *
+ * CREATE FUNCTION foo(x INT) RETURNS TABLE(a INT) RETURN SELECT x + 1 AS x1
+ *
+ * Query:
+ *
+ * SELECT * FROM foo(1);
+ *
+ * Plan:
+ *
+ * Project [CAST(x1 AS INT) AS a]
+ * +- LateralJoin lateral-subquery [x]
+ * : +- Project [(outer(x) + 1) AS x1]
+ * : +- OneRowRelation
+ * +- Project [CAST(1 AS INT) AS x]
+ * +- OneRowRelation
+ */
+ def makeSQLTableFunctionPlan(
+ name: String,
+ function: SQLFunction,
+ input: Seq[Expression],
+ outputAttrs: Seq[Attribute]): LogicalPlan = {
+ assert(function.isTableFunc)
+ val funcName = function.name.funcName
+ val inputParam = function.inputParam
+ val returnParam = function.getTableFuncReturnCols
+ val (_, query) = function.getExpressionAndQuery(parser, isTableFunc = true)
+ assert(query.isDefined)
+
+ // Check function arguments
+ val paramSize = inputParam.map(_.size).getOrElse(0)
+ if (input.size > paramSize) {
+ throw QueryCompilationErrors.wrongNumArgsError(
+ name, paramSize.toString, input.size)
+ }
+
+ val body = if (inputParam.isDefined) {
+ val param = inputParam.get
+ // Attributes referencing the input parameters inside the function can use the
+ // function name as a qualifier.
+ val qualifier = Seq(funcName)
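+ // Pad missing trailing arguments with their default expressions; a parameter without a
+ // default results in a wrong-number-of-arguments error.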
+ val paddedInput = input ++
+ param.takeRight(paramSize - input.size).map { p =>
+ val defaultExpr = p.getDefault()
+ if (defaultExpr.isDefined) {
+ parseDefault(defaultExpr.get, parser)
+ } else {
+ throw QueryCompilationErrors.wrongNumArgsError(
+ name, paramSize.toString, input.size)
+ }
+ }
+
+ val inputCast = paddedInput.zip(param.fields).map {
+ case (expr, param) =>
+ // Add outer references to all attributes in the function input.
+ val outer = expr.transform {
+ case a: Attribute => OuterReference(a)
+ }
+ Alias(Cast(outer, param.dataType), param.name)(qualifier = qualifier)
+ }
+ val inputPlan = Project(inputCast, OneRowRelation())
+ LateralJoin(inputPlan, LateralSubquery(query.get), Inner, None)
+ } else {
+ query.get
+ }
+
+ assert(returnParam.length == outputAttrs.length)
+ val output = returnParam.fields.zipWithIndex.map { case (param, i) =>
+ // Since we cannot get the output of an unresolved logical plan, we need
+ // to reference the output column of the lateral join by its position.
+ val child = Cast(GetColumnByOrdinal(paramSize + i, param.dataType), param.dataType)
+ Alias(child, param.name)(exprId = outputAttrs(i).exprId)
+ }
+ SQLFunctionNode(function, SubqueryAlias(funcName, Project(output.toSeq, body)))
+ }
+
/**
* Constructs a [[FunctionBuilder]] based on the provided function metadata.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
index a76ca7b15c278..8ed2414683522 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
@@ -86,6 +86,11 @@ object UserDefinedFunction {
// The default Hive Metastore SQL schema length for function resource uri.
private val HIVE_FUNCTION_RESOURCE_URI_LENGTH_THRESHOLD: Int = 4000
+ def parseRoutineParam(text: String, parser: ParserInterface): StructType = {
+ val parsed = parser.parseRoutineParam(text)
+ CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
+ }
+
def parseTableSchema(text: String, parser: ParserInterface): StructType = {
val parsed = parser.parseTableSchema(text)
CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 1cb3520d4e265..d92d2881445ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -36,7 +36,7 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CurrentUserContext, FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, Resolver, SchemaBinding, SchemaCompensation, SchemaEvolution, SchemaTypeEvolution, SchemaUnsupported, UnresolvedLeafNode, ViewSchemaMode}
+import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NormalizeableRelation, Resolver, SchemaBinding, SchemaCompensation, SchemaEvolution, SchemaTypeEvolution, SchemaUnsupported, UnresolvedLeafNode, ViewSchemaMode}
import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
@@ -156,7 +156,7 @@ case class CatalogStorageFormat(
def toJsonLinkedHashMap: mutable.LinkedHashMap[String, JValue] = {
val map = mutable.LinkedHashMap[String, JValue]()
- locationUri.foreach(l => map += ("Location" -> JString(l.toString)))
+ locationUri.foreach(l => map += ("Location" -> JString(CatalogUtils.URIToString(l))))
serde.foreach(s => map += ("Serde Library" -> JString(s)))
inputFormat.foreach(format => map += ("InputFormat" -> JString(format)))
outputFormat.foreach(format => map += ("OutputFormat" -> JString(format)))
@@ -1082,7 +1082,7 @@ case class HiveTableRelation(
partitionCols: Seq[AttributeReference],
tableStats: Option[Statistics] = None,
@transient prunedPartitions: Option[Seq[CatalogTablePartition]] = None)
- extends LeafNode with MultiInstanceRelation {
+ extends LeafNode with MultiInstanceRelation with NormalizeableRelation {
assert(tableMeta.identifier.database.isDefined)
assert(DataTypeUtils.sameType(tableMeta.partitionSchema, partitionCols.toStructType))
assert(DataTypeUtils.sameType(tableMeta.dataSchema, dataCols.toStructType))
@@ -1150,4 +1150,11 @@ case class HiveTableRelation(
val metadataStr = truncatedString(metadataEntries, "[", ", ", "]", maxFields)
s"$nodeName $metadataStr"
}
+
+ /**
+ * Minimally normalizes this [[HiveTableRelation]] to make it comparable in [[NormalizePlan]].
+ */
+ override def normalize(): LogicalPlan = {
+ copy(tableMeta = CatalogTable.normalize(tableMeta))
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala
index 47e2e288357e1..1c6eecad170f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala
@@ -52,6 +52,12 @@ class CSVHeaderChecker(
// the column name don't conform to the schema, an exception is thrown.
private val enforceSchema = options.enforceSchema
+ // When `options.singleVariantColumn` is defined, `headerColumnNames` will be set to the header
+ // column names and no check will happen (because any name is valid).
+ private var headerColumnNames: Option[Array[String]] = None
+ // See `CSVDataSource.setHeaderForSingleVariantColumn` for details.
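+ // Illustrative wiring (assumed): the data source can register a callback such as
+ // `setHeaderForSingleVariantColumn = Some(names => parser.headerColumnNames = names)`
+ // so that the parser receives the header names once the header line has been read.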
+ var setHeaderForSingleVariantColumn: Option[Option[Array[String]] => Unit] = None
+
/**
* Checks that column names in a CSV header and field names in the schema are the same
* by taking into account case sensitivity.
@@ -60,6 +66,11 @@ class CSVHeaderChecker(
*/
private def checkHeaderColumnNames(columnNames: Array[String]): Unit = {
if (columnNames != null) {
+ if (options.singleVariantColumn.isDefined) {
+ headerColumnNames = Some(columnNames)
+ return
+ }
+
val fieldNames = schema.map(_.name).toIndexedSeq
val (headerLen, schemaSize) = (columnNames.length, fieldNames.length)
var errorMessage: Option[MessageWithContext] = None
@@ -122,6 +133,7 @@ class CSVHeaderChecker(
val firstRecord = tokenizer.parseNext()
checkHeaderColumnNames(firstRecord)
}
+ setHeaderForSingleVariantColumn.foreach(f => f(headerColumnNames))
}
// This is currently only used to parse CSV with non-multiLine mode.
@@ -137,5 +149,6 @@ class CSVHeaderChecker(
checkHeaderColumnNames(tokenizer.parseLine(header))
}
}
+ setHeaderForSingleVariantColumn.foreach(f => f(headerColumnNames))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
index 6c68bc1aa5890..b43f124ed19c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
@@ -306,6 +306,14 @@ class CSVOptions(
private val isColumnPruningOptionEnabled: Boolean =
getBool(COLUMN_PRUNING, !multiLine && columnPruning)
+ // This option takes in a column name and specifies that the entire CSV record should be stored
+ // as a single VARIANT type column in the table with the given column name.
+ // E.g. spark.read.format("csv").option("singleVariantColumn", "colName")
+ val singleVariantColumn: Option[String] = parameters.get(SINGLE_VARIANT_COLUMN)
+
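+ // In singleVariantColumn mode, the header (when present) supplies the field names of the
+ // variant object.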
+ def needHeaderForSingleVariantColumn: Boolean =
+ singleVariantColumn.isDefined && headerFlag
+
def asWriterSettings: CsvWriterSettings = {
val writerSettings = new CsvWriterSettings()
val format = writerSettings.getFormat
@@ -388,7 +396,7 @@ object CSVOptions extends DataSourceOptions {
val EMPTY_VALUE = newOption("emptyValue")
val LINE_SEP = newOption("lineSep")
val INPUT_BUFFER_SIZE = newOption("inputBufferSize")
- val COLUMN_NAME_OF_CORRUPT_RECORD = newOption("columnNameOfCorruptRecord")
+ val COLUMN_NAME_OF_CORRUPT_RECORD = newOption(DataSourceOptions.COLUMN_NAME_OF_CORRUPT_RECORD)
val NULL_VALUE = newOption("nullValue")
val NAN_VALUE = newOption("nanValue")
val POSITIVE_INF = newOption("positiveInf")
@@ -407,4 +415,5 @@ object CSVOptions extends DataSourceOptions {
val DELIMITER = "delimiter"
newOption(SEP, DELIMITER)
val COLUMN_PRUNING = newOption("columnPruning")
+ val SINGLE_VARIANT_COLUMN = newOption(DataSourceOptions.SINGLE_VARIANT_COLUMN)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
index 0fd0601803a6a..38400ec362750 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.csv
import java.io.InputStream
+import java.util.Locale
+import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import com.univocity.parsers.common.TextParsingException
@@ -34,7 +36,8 @@ import org.apache.spark.sql.errors.{ExecutionErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.types.variant._
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
/**
* Constructs a parser for a given schema that translates CSV data to an [[InternalRow]].
@@ -135,6 +138,11 @@ class UnivocityParser(
options.dateFormatOption.isEmpty
}
+ // When `options.needHeaderForSingleVariantColumn` is true, `headerColumnNames` will be set to
+ // the header column names by `CSVDataSource.readHeaderForSingleVariantColumn`.
+ var headerColumnNames: Option[Array[String]] = None
+ private val singleVariantFieldConverters = new ArrayBuffer[VariantValueConverter]()
+
// Retrieve the raw record string.
private def getCurrentInput: UTF8String = {
if (tokenizer.getContext == null) return null
@@ -161,9 +169,12 @@ class UnivocityParser(
// Each input token is placed in each output row's position by mapping these. In this case,
//
// output row - ["A", 2]
- private val valueConverters: Array[ValueConverter] = {
- requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable)).toArray
- }
+ private val valueConverters: Array[ValueConverter] =
+ if (options.singleVariantColumn.isDefined) {
+ null
+ } else {
+ requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable)).toArray
+ }
private val decimalParser = ExprUtils.getDecimalParser(options.locale)
@@ -274,6 +285,8 @@ class UnivocityParser(
UTF8String.fromString(datum), dt.startField, dt.endField)
}
+ case _: VariantType => new VariantValueConverter
+
case udt: UserDefinedType[_] =>
makeConverter(name, udt.sqlType, nullable)
@@ -331,6 +344,46 @@ class UnivocityParser(
(tokens: Array[String], index: Int) => tokens(tokenIndexArr(index))
}
+ /**
+ * The entire line of CSV data is collected into a single variant object. When `headerColumnNames`
+ * is defined, the field names will be extracted from it. Otherwise, the field names will have the
+ * format "_c$i" to match the position of the values in the CSV data.
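+ *
+ * For example (illustrative): given header "a,b" and record "1,x", the variant object is
+ * {"a": 1, "b": "x"}; without a header the field names are "_c0" and "_c1".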
+ */
+ protected final def convertSingleVariantRow(
+ tokens: Array[String],
+ currentInput: UTF8String): GenericInternalRow = {
+ val row = new GenericInternalRow(1)
+ try {
+ val keys = headerColumnNames.orNull
+ val numFields = if (keys != null) tokens.length.min(keys.length) else tokens.length
+ if (singleVariantFieldConverters.length < numFields) {
+ val extra = numFields - singleVariantFieldConverters.length
+ singleVariantFieldConverters.appendAll(Array.fill(extra)(new VariantValueConverter))
+ }
+ val builder = new VariantBuilder(false)
+ val start = builder.getWritePos
+ val fields = new java.util.ArrayList[VariantBuilder.FieldEntry](numFields)
+ for (i <- 0 until numFields) {
+ val key = if (keys != null) keys(i) else "_c" + i
+ val id = builder.addKey(key)
+ fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start))
+ singleVariantFieldConverters(i).convertInput(builder, tokens(i))
+ }
+ builder.finishWritingObject(start, fields)
+ val v = builder.result()
+ row(0) = new VariantVal(v.getValue, v.getMetadata)
+ // If the header line has a different number of tokens than the content line, the CSV data is
+ // malformed. We may still have partially parsed data in `row`.
+ if (keys != null && keys.length != tokens.length) {
+ throw QueryExecutionErrors.malformedCSVRecordError(currentInput.toString)
+ }
+ row
+ } catch {
+ case NonFatal(e) =>
+ throw BadRecordException(() => currentInput, () => Array(row), cause = e)
+ }
+ }
+
private def convert(tokens: Array[String]): Option[InternalRow] = {
if (tokens == null) {
throw BadRecordException(
@@ -341,6 +394,10 @@ class UnivocityParser(
val currentInput = getCurrentInput
+ if (options.singleVariantColumn.isDefined) {
+ return Some(convertSingleVariantRow(tokens, currentInput))
+ }
+
var badRecordException: Option[Throwable] = if (tokens.length != parsedSchema.length) {
// If the number of tokens doesn't match the schema, we should treat it as a malformed record.
// However, we still have chance to parse some of the tokens. It continues to parses the
@@ -386,6 +443,132 @@ class UnivocityParser(
}
}
}
+
+ /**
+ * This class converts a comma-separated value into a variant column (when the schema contains a
+ * variant type) or a variant field (when in singleVariantColumn mode).
+ *
+ * It has a list of scalar types to try (long, decimal, date, timestamp, boolean) and maintains
+ * the current content type. It tries to parse the input as the current content type. If the
+ * parsing fails, it moves on to the next type in the list and keeps trying. It never rechecks
+ * types that have already failed. In the end, it either successfully parses the
+ * input as a specific scalar type, or fails after trying all the types and defaults to the string
+ * type. The state is reset for every input file.
+ *
+ * Floating point types (double, float) are not considered to avoid precision loss.
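+ *
+ * Illustrative progression: given the inputs "12", "3.14" and "hello" within one file, the state
+ * moves from long to decimal and then falls back to string; once it reaches the string type it
+ * stays there until the next file resets it.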
+ */
+ private final class VariantValueConverter extends ValueConverter {
+ private var currentType: DataType = LongType
+ // Keep consistent with `CSVInferSchema`: only produce TimestampNTZ when the default timestamp
+ // type is TimestampNTZ.
+ private val isDefaultNTZ = SQLConf.get.timestampType == TimestampNTZType
+
+ override def apply(s: String): Any = {
+ val builder = new VariantBuilder(false)
+ convertInput(builder, s)
+ val v = builder.result()
+ new VariantVal(v.getValue, v.getMetadata)
+ }
+
+ def convertInput(builder: VariantBuilder, s: String): Unit = {
+ if (s == null || s == options.nullValue) {
+ builder.appendNull()
+ return
+ }
+
+ def parseLong(): DataType = {
+ try {
+ builder.appendLong(s.toLong)
+ // The actual integral type doesn't matter. `appendLong` will use the smallest possible
+ // integral type to store the value.
+ LongType
+ } catch {
+ case NonFatal(_) => parseDecimal()
+ }
+ }
+
+ def parseDecimal(): DataType = {
+ try {
+ var d = decimalParser(s)
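+ // Normalize a negative scale (e.g. from input like "1E+2") to 0, since the variant decimal
+ // encoding assumes a non-negative scale.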
+ if (d.scale() < 0) {
+ d = d.setScale(0)
+ }
+ if (d.scale() <= VariantUtil.MAX_DECIMAL16_PRECISION &&
+ d.precision() <= VariantUtil.MAX_DECIMAL16_PRECISION) {
+ builder.appendDecimal(d)
+ // The actual decimal type doesn't matter. `appendDecimal` will use the smallest
+ // possible decimal type to store the value.
+ DecimalType.USER_DEFAULT
+ } else {
+ if (options.preferDate) parseDate() else parseTimestampNTZ()
+ }
+ } catch {
+ case NonFatal(_) =>
+ if (options.preferDate) parseDate() else parseTimestampNTZ()
+ }
+ }
+
+ def parseDate(): DataType = {
+ try {
+ builder.appendDate(dateFormatter.parse(s))
+ DateType
+ } catch {
+ case NonFatal(_) => parseTimestampNTZ()
+ }
+ }
+
+ def parseTimestampNTZ(): DataType = {
+ if (isDefaultNTZ) {
+ try {
+ builder.appendTimestampNtz(timestampNTZFormatter.parseWithoutTimeZone(s, false))
+ TimestampNTZType
+ } catch {
+ case NonFatal(_) => parseTimestamp()
+ }
+ } else {
+ parseTimestamp()
+ }
+ }
+
+ def parseTimestamp(): DataType = {
+ try {
+ builder.appendTimestamp(timestampFormatter.parse(s))
+ TimestampType
+ } catch {
+ case NonFatal(_) => parseBoolean()
+ }
+ }
+
+ def parseBoolean(): DataType = {
+ val lower = s.toLowerCase(Locale.ROOT)
+ if (lower == "true") {
+ builder.appendBoolean(true)
+ BooleanType
+ } else if (lower == "false") {
+ builder.appendBoolean(false)
+ BooleanType
+ } else {
+ parseString()
+ }
+ }
+
+ def parseString(): DataType = {
+ builder.appendString(s)
+ StringType
+ }
+
+ val newType = currentType match {
+ case LongType => parseLong()
+ case _: DecimalType => parseDecimal()
+ case DateType => parseDate()
+ case TimestampNTZType => parseTimestampNTZ()
+ case TimestampType => parseTimestamp()
+ case BooleanType => parseBoolean()
+ case StringType => parseString()
+ }
+ currentType = newType
+ }
+ }
}
private[sql] object UnivocityParser {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index f1904c2436ab8..00bde9f8c1f72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -24,6 +24,7 @@ import scala.language.implicitConversions
import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -64,7 +65,7 @@ import org.apache.spark.unsafe.types.UTF8String
* LocalRelation [key#2,value#3], []
* }}}
*/
-package object dsl {
+package object dsl extends SQLConfHelper {
trait ImplicitOperators {
def expr: Expression
@@ -402,6 +403,8 @@ package object dsl {
def localLimit(limitExpr: Expression): LogicalPlan = LocalLimit(limitExpr, logicalPlan)
+ def globalLimit(limitExpr: Expression): LogicalPlan = GlobalLimit(limitExpr, logicalPlan)
+
def offset(offsetExpr: Expression): LogicalPlan = Offset(offsetExpr, logicalPlan)
def join(
@@ -444,11 +447,16 @@ package object dsl {
def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan)
def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = {
+ // Replace top-level integer literals with ordinals, if `groupByOrdinal` is enabled.
+ val groupingExpressionsWithOrdinals = groupingExprs.map {
+ case Literal(value: Int, IntegerType) if conf.groupByOrdinal => UnresolvedOrdinal(value)
+ case other => other
+ }
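+ // e.g. with groupByOrdinal enabled, groupBy(Literal(1))(...) groups by UnresolvedOrdinal(1)
+ // rather than by the literal 1.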
val aliasedExprs = aggregateExprs.map {
case ne: NamedExpression => ne
case e => UnresolvedAlias(e)
}
- Aggregate(groupingExprs, aliasedExprs, logicalPlan)
+ Aggregate(groupingExpressionsWithOrdinals, aliasedExprs, logicalPlan)
}
def having(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
index 81743251bada9..0de459e1196d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
@@ -18,14 +18,39 @@ package org.apache.spark.sql.catalyst.encoders
import scala.collection.Map
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder, VariantEncoder}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalIntegerType, PhysicalLongType}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, TimeType, UserDefinedType, VariantType, YearMonthIntervalType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
+/**
+ * :: DeveloperApi ::
+ * Extensible [[AgnosticEncoder]] providing conversion extension points over type T
+ * @tparam T the type handled by this encoder
+ */
+@DeveloperApi
+@deprecated("This trait is intended only as a migration tool and will be removed in 4.1")
+trait AgnosticExpressionPathEncoder[T]
+ extends AgnosticEncoder[T] {
+ /**
+ * Converts from T to InternalRow
+ * @param input the starting input path
+ * @return the Catalyst expression that performs the serialization
+ */
+ def toCatalyst(input: Expression): Expression
+
+ /**
+ * Converts from InternalRow to T
+ * @param inputPath path expression from InternalRow
+ * @return the Catalyst expression that performs the deserialization
+ */
+ def fromCatalyst(inputPath: Expression): Expression
+}
+
/**
* Helper class for Generating [[ExpressionEncoder]]s.
*/
@@ -77,6 +102,7 @@ object EncoderUtils {
case _: DecimalType => classOf[Decimal]
case _: DayTimeIntervalType => classOf[PhysicalLongType.InternalType]
case _: YearMonthIntervalType => classOf[PhysicalIntegerType.InternalType]
+ case _: TimeType => classOf[PhysicalLongType.InternalType]
case _: StringType => classOf[UTF8String]
case _: StructType => classOf[InternalRow]
case _: ArrayType => classOf[ArrayData]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index b92acfb5b0f3a..084be5a350459 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, JavaTypeInference, ScalaReflection, SerializerBuildHelper}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{OptionEncoder, TransformingEncoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
@@ -215,6 +216,13 @@ case class ExpressionEncoder[T](
StructField(s.name, s.dataType, s.nullable)
})
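+ // Returns true if the encoder ultimately encodes an Option, possibly wrapped in one or more
+ // TransformingEncoders.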
+ private def transformerOfOption(enc: AgnosticEncoder[_]): Boolean =
+ enc match {
+ case t: TransformingEncoder[_, _] => transformerOfOption(t.transformed)
+ case _: OptionEncoder[_] => true
+ case _ => false
+ }
+
/**
* Returns true if the type `T` is serialized as a struct by `objSerializer`.
*/
@@ -228,7 +236,8 @@ case class ExpressionEncoder[T](
* returns true if `T` is serialized as struct and is not `Option` type.
*/
def isSerializedAsStructForTopLevel: Boolean = {
- isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)
+ isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) &&
+ !transformerOfOption(encoder)
}
// serializer expressions are used to encode an object to a row, while the object is usually an
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8773d7a6a029e..7a4145933fc7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -119,6 +119,7 @@ object Cast extends QueryErrorsBase {
case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
case (_: StringType, DateType) => true
+ case (_: StringType, _: TimeType) => true
case (TimestampType, DateType) => true
case (TimestampNTZType, DateType) => true
@@ -133,19 +134,23 @@ object Cast extends QueryErrorsBase {
// to convert data of these types to Variant Objects.
case (_, VariantType) => variant.VariantGet.checkDataType(from, allowStructsAndMaps = false)
+ // non-null variants can generate nulls even in ANSI mode
case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
- canAnsiCast(fromType, toType) && resolvableNullability(fn, tn)
+ canAnsiCast(fromType, toType) && resolvableNullability(fn || (fromType == VariantType), tn)
+ // non-null variants can generate nulls even in ANSI mode
case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
canAnsiCast(fromKey, toKey) && canAnsiCast(fromValue, toValue) &&
- resolvableNullability(fn, tn)
+ resolvableNullability(fn || (fromValue == VariantType), tn)
+ // non-null variants can generate nulls even in ANSI mode
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall {
case (fromField, toField) =>
canAnsiCast(fromField.dataType, toField.dataType) &&
- resolvableNullability(fromField.nullable, toField.nullable)
+ resolvableNullability(fromField.nullable || (fromField.dataType == VariantType),
+ toField.nullable)
}
case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true
@@ -219,6 +224,7 @@ object Cast extends QueryErrorsBase {
case (TimestampType, TimestampNTZType) => true
case (_: StringType, DateType) => true
+ case (_: StringType, _: TimeType) => true
case (TimestampType, DateType) => true
case (TimestampNTZType, DateType) => true
@@ -727,6 +733,15 @@ case class Cast(
buildCast[Long](_, t => microsToDays(t, ZoneOffset.UTC))
}
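+ // Illustrative: casting a string such as "10:30:45.123" to a TIME value goes through
+ // DateTimeUtils.stringToTime; unparsable input yields NULL, or an error under ANSI mode.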
+ private[this] def castToTime(from: DataType): Any => Any = from match {
+ case _: StringType =>
+ if (ansiEnabled) {
+ buildCast[UTF8String](_, s => DateTimeUtils.stringToTimeAnsi(s, getContextOrNull()))
+ } else {
+ buildCast[UTF8String](_, s => DateTimeUtils.stringToTime(s).orNull)
+ }
+ }
+
// IntervalConverter
private[this] def castToInterval(from: DataType): Any => Any = from match {
case _: StringType =>
@@ -1134,6 +1149,7 @@ case class Cast(
case s: StringType => castToString(from, s.constraint)
case BinaryType => castToBinary(from)
case DateType => castToDate(from)
+ case _: TimeType => castToTime(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
case TimestampNTZType => castToTimestampNTZ(from)
@@ -1241,6 +1257,7 @@ case class Cast(
(c, evPrim, _) => castToStringCode(from, ctx, s.constraint).apply(c, evPrim)
case BinaryType => castToBinaryCode(from)
case DateType => castToDateCode(from, ctx)
+ case _: TimeType => castToTimeCode(from, ctx)
case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
case TimestampType => castToTimestampCode(from, ctx)
case TimestampNTZType => castToTimestampNTZCode(from, ctx)
@@ -1313,8 +1330,7 @@ case class Cast(
"""
} else {
code"""
- scala.Option $intOpt =
- org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c);
+ scala.Option $intOpt = $dateTimeUtilsCls.stringToDate($c);
if ($intOpt.isDefined()) {
$evPrim = ((Integer) $intOpt.get()).intValue();
} else {
@@ -1327,8 +1343,7 @@ case class Cast(
val zidClass = classOf[ZoneId]
val zid = JavaCode.global(ctx.addReferenceObj("zoneId", zoneId, zidClass.getName), zidClass)
(c, evPrim, evNull) =>
- code"""$evPrim =
- org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToDays($c, $zid);"""
+ code"""$evPrim = $dateTimeUtilsCls.microsToDays($c, $zid);"""
case TimestampNTZType =>
(c, evPrim, evNull) =>
code"$evPrim = $dateTimeUtilsCls.microsToDays($c, java.time.ZoneOffset.UTC);"
@@ -1337,6 +1352,34 @@ case class Cast(
}
}
+ private[this] def castToTimeCode(
+ from: DataType,
+ ctx: CodegenContext): CastFunction = {
+ from match {
+ case _: StringType =>
+ val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
+ (c, evPrim, evNull) =>
+ if (ansiEnabled) {
+ val errorContext = getContextOrNullCode(ctx)
+ code"""
+ $evPrim = $dateTimeUtilsCls.stringToTimeAnsi($c, $errorContext);
+ """
+ } else {
+ code"""
+ scala.Option $longOpt = $dateTimeUtilsCls.stringToTime($c);
+ if ($longOpt.isDefined()) {
+ $evPrim = ((Long) $longOpt.get()).longValue();
+ } else {
+ $evNull = true;
+ }
+ """
+ }
+
+ case _ =>
+ (_, _, evNull) => code"$evNull = true;"
+ }
+ }
+
private[this] def changePrecision(
d: ExprValue,
decimalType: DecimalType,
@@ -1481,8 +1524,7 @@ case class Cast(
"""
} else {
code"""
- scala.Option $longOpt =
- org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $zid);
+ scala.Option $longOpt = $dateTimeUtilsCls.stringToTimestamp($c, $zid);
if ($longOpt.isDefined()) {
$evPrim = ((Long) $longOpt.get()).longValue();
} else {
@@ -1500,8 +1542,7 @@ case class Cast(
ctx.addReferenceObj("zoneId", zoneId, zoneIdClass.getName),
zoneIdClass)
(c, evPrim, evNull) =>
- code"""$evPrim =
- org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMicros($c, $zid);"""
+ code"""$evPrim = $dateTimeUtilsCls.daysToMicros($c, $zid);"""
case TimestampNTZType =>
val zoneIdClass = classOf[ZoneId]
val zid = JavaCode.global(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala
index 62a1afecfd7f0..4a074bb3039b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala
@@ -38,7 +38,7 @@ abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging {
def createObject(in: IN): OUT = {
// We are allowed to choose codegen-only or no-codegen modes if under tests.
- val fallbackMode = CodegenObjectFactoryMode.withName(SQLConf.get.codegenFactoryMode)
+ val fallbackMode = SQLConf.get.codegenFactoryMode
fallbackMode match {
case CodegenObjectFactoryMode.CODEGEN_ONLY =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 7d4f8c3b2564f..8994722047caa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -800,9 +800,13 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes wi
override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)
+ protected def sameType(left: DataType, right: DataType): Boolean = {
+ DataTypeUtils.sameType(left, right)
+ }
+
override def checkInputDataTypes(): TypeCheckResult = {
// First check whether left and right have the same type, then check if the type is acceptable.
- if (!DataTypeUtils.sameType(left.dataType, right.dataType)) {
+ if (!sameType(left.dataType, right.dataType)) {
DataTypeMismatch(
errorSubClass = "BINARY_OP_DIFF_TYPES",
messageParameters = Map(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
index fb6ebc899d8fa..971cfcae8e478 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
@@ -196,7 +196,7 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen
// We use INT for DATE and YearMonthIntervalType internally
case IntegerType | DateType | _: YearMonthIntervalType => new MutableInt
// We use Long for Timestamp, Timestamp without time zone and DayTimeInterval internally
- case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
+ case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType =>
new MutableLong
case FloatType => new MutableFloat
case DoubleType => new MutableDouble
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
index de72b94df3ac5..2e649763a9ac9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
@@ -22,7 +22,7 @@ import java.time.ZoneOffset
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateFormatter, IntervalStringStyles, IntervalUtils, MapData, SparkStringUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateFormatter, FractionTimeFormatter, IntervalStringStyles, IntervalUtils, MapData, SparkStringUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle
@@ -34,6 +34,7 @@ import org.apache.spark.util.ArrayImplicits._
trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
private lazy val dateFormatter = DateFormatter()
+ private lazy val timeFormatter = new FractionTimeFormatter()
private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId)
private lazy val timestampNTZFormatter = TimestampFormatter.getFractionFormatter(ZoneOffset.UTC)
@@ -73,6 +74,8 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
acceptAny[Long](t => UTF8String.fromString(timestampFormatter.format(t)))
case TimestampNTZType =>
acceptAny[Long](t => UTF8String.fromString(timestampNTZFormatter.format(t)))
+ case _: TimeType =>
+ acceptAny[Long](t => UTF8String.fromString(timeFormatter.format(t)))
case ArrayType(et, _) =>
acceptAny[ArrayData](array => {
val builder = new UTF8StringBuilder
@@ -224,6 +227,11 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
ctx.addReferenceObj("timestampNTZFormatter", timestampNTZFormatter),
timestampNTZFormatter.getClass)
(c, evPrim) => code"$evPrim = UTF8String.fromString($tf.format($c));"
+ case _: TimeType =>
+ val tf = JavaCode.global(
+ ctx.addReferenceObj("timeFormatter", timeFormatter),
+ timeFormatter.getClass)
+ (c, evPrim) => code"$evPrim = UTF8String.fromString($tf.format($c));"
case CalendarIntervalType =>
(c, evPrim) => code"$evPrim = UTF8String.fromString($c.toString());"
case ArrayType(et, _) =>
@@ -449,7 +457,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression =>
object ToStringBase {
def getBinaryFormatter: BinaryFormatter = {
val style = SQLConf.get.getConf(SQLConf.BINARY_OUTPUT_STYLE)
- style.map(BinaryOutputStyle.withName) match {
+ style match {
case Some(BinaryOutputStyle.UTF8) =>
(array: Array[Byte]) => UTF8String.fromBytes(array)
case Some(BinaryOutputStyle.BASIC) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index d14c8cb675387..7cc03f3ac3fa6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -23,14 +23,15 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{FUNCTION_NAME, FUNCTION_PARAM}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
-import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
+import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
-import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
+import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
@@ -205,4 +206,171 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
None
}
}
+
+ def toCatalyst(expr: V2Expression): Option[Expression] = expr match {
+ case _: AlwaysTrue => Some(Literal.TrueLiteral)
+ case _: AlwaysFalse => Some(Literal.FalseLiteral)
+ case l: V2Literal[_] => Some(Literal(l.value, l.dataType))
+ case r: NamedReference => Some(UnresolvedAttribute(r.fieldNames.toImmutableArraySeq))
+ case c: V2Cast => toCatalyst(c.expression).map(Cast(_, c.dataType, ansiEnabled = true))
+ case e: GeneralScalarExpression => convertScalarExpr(e)
+ case _ => None
+ }
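+ // Illustrative example of `toCatalyst` (assumed): a V2 predicate built as
+ // new GeneralScalarExpression("=", Array(FieldReference("id"), LiteralValue(1, IntegerType)))
+ // converts to EqualTo(UnresolvedAttribute("id"), Literal(1)); if any child cannot be converted,
+ // the whole conversion returns None.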
+
+ private def convertScalarExpr(expr: GeneralScalarExpression): Option[Expression] = {
+ convertPredicate(expr)
+ .orElse(convertConditionalFunc(expr))
+ .orElse(convertMathFunc(expr))
+ .orElse(convertBitwiseFunc(expr))
+ .orElse(convertTrigonometricFunc(expr))
+ }
+
+ private def convertPredicate(expr: GeneralScalarExpression): Option[Expression] = {
+ expr.name match {
+ case "IS_NULL" => convertUnaryExpr(expr, IsNull)
+ case "IS_NOT_NULL" => convertUnaryExpr(expr, IsNotNull)
+ case "NOT" => convertUnaryExpr(expr, Not)
+ case "=" => convertBinaryExpr(expr, EqualTo)
+ case "<=>" => convertBinaryExpr(expr, EqualNullSafe)
+ case ">" => convertBinaryExpr(expr, GreaterThan)
+ case ">=" => convertBinaryExpr(expr, GreaterThanOrEqual)
+ case "<" => convertBinaryExpr(expr, LessThan)
+ case "<=" => convertBinaryExpr(expr, LessThanOrEqual)
+ case "<>" => convertBinaryExpr(expr, (left, right) => Not(EqualTo(left, right)))
+ case "AND" => convertBinaryExpr(expr, And)
+ case "OR" => convertBinaryExpr(expr, Or)
+ case "STARTS_WITH" => convertBinaryExpr(expr, StartsWith)
+ case "ENDS_WITH" => convertBinaryExpr(expr, EndsWith)
+ case "CONTAINS" => convertBinaryExpr(expr, Contains)
+ case "IN" => convertExpr(expr, children => In(children.head, children.tail))
+ case _ => None
+ }
+ }
+
+ private def convertConditionalFunc(expr: GeneralScalarExpression): Option[Expression] = {
+ expr.name match {
+ case "CASE_WHEN" =>
+ convertExpr(expr, children =>
+ if (children.length % 2 == 0) {
+ val branches = children.grouped(2).map { case Seq(c, v) => (c, v) }.toSeq
+ CaseWhen(branches, None)
+ } else {
+ val (pairs, last) = children.splitAt(children.length - 1)
+ val branches = pairs.grouped(2).map { case Seq(c, v) => (c, v) }.toSeq
+ CaseWhen(branches, Some(last.head))
+ })
+ case _ => None
+ }
+ }
+
+ private def convertMathFunc(expr: GeneralScalarExpression): Option[Expression] = {
+ expr.name match {
+ case "+" => convertBinaryExpr(expr, Add(_, _, evalMode = EvalMode.ANSI))
+ case "-" =>
+ if (expr.children.length == 1) {
+ convertUnaryExpr(expr, UnaryMinus(_, failOnError = true))
+ } else if (expr.children.length == 2) {
+ convertBinaryExpr(expr, Subtract(_, _, evalMode = EvalMode.ANSI))
+ } else {
+ None
+ }
+ case "*" => convertBinaryExpr(expr, Multiply(_, _, evalMode = EvalMode.ANSI))
+ case "/" => convertBinaryExpr(expr, Divide(_, _, evalMode = EvalMode.ANSI))
+ case "%" => convertBinaryExpr(expr, Remainder(_, _, evalMode = EvalMode.ANSI))
+ case "ABS" => convertUnaryExpr(expr, Abs(_, failOnError = true))
+ case "COALESCE" => convertExpr(expr, Coalesce)
+ case "GREATEST" => convertExpr(expr, Greatest)
+ case "LEAST" => convertExpr(expr, Least)
+ case "RAND" =>
+ if (expr.children.isEmpty) {
+ Some(new Rand())
+ } else if (expr.children.length == 1) {
+ convertUnaryExpr(expr, new Rand(_))
+ } else {
+ None
+ }
+ case "LOG" => convertBinaryExpr(expr, Logarithm)
+ case "LOG10" => convertUnaryExpr(expr, Log10)
+ case "LOG2" => convertUnaryExpr(expr, Log2)
+ case "LN" => convertUnaryExpr(expr, Log)
+ case "EXP" => convertUnaryExpr(expr, Exp)
+ case "POWER" => convertBinaryExpr(expr, Pow)
+ case "SQRT" => convertUnaryExpr(expr, Sqrt)
+ case "FLOOR" => convertUnaryExpr(expr, Floor)
+ case "CEIL" => convertUnaryExpr(expr, Ceil)
+ case "ROUND" => convertBinaryExpr(expr, Round(_, _, ansiEnabled = true))
+ case "CBRT" => convertUnaryExpr(expr, Cbrt)
+ case "DEGREES" => convertUnaryExpr(expr, ToDegrees)
+ case "RADIANS" => convertUnaryExpr(expr, ToRadians)
+ case "SIGN" => convertUnaryExpr(expr, Signum)
+ case "WIDTH_BUCKET" =>
+ convertExpr(
+ expr,
+ children => WidthBucket(children(0), children(1), children(2), children(3)))
+ case _ => None
+ }
+ }
+
+ private def convertTrigonometricFunc(expr: GeneralScalarExpression): Option[Expression] = {
+ expr.name match {
+ case "SIN" => convertUnaryExpr(expr, Sin)
+ case "SINH" => convertUnaryExpr(expr, Sinh)
+ case "COS" => convertUnaryExpr(expr, Cos)
+ case "COSH" => convertUnaryExpr(expr, Cosh)
+ case "TAN" => convertUnaryExpr(expr, Tan)
+ case "TANH" => convertUnaryExpr(expr, Tanh)
+ case "COT" => convertUnaryExpr(expr, Cot)
+ case "ASIN" => convertUnaryExpr(expr, Asin)
+ case "ASINH" => convertUnaryExpr(expr, Asinh)
+ case "ACOS" => convertUnaryExpr(expr, Acos)
+ case "ACOSH" => convertUnaryExpr(expr, Acosh)
+ case "ATAN" => convertUnaryExpr(expr, Atan)
+ case "ATANH" => convertUnaryExpr(expr, Atanh)
+ case "ATAN2" => convertBinaryExpr(expr, Atan2)
+ case _ => None
+ }
+ }
+
+ private def convertBitwiseFunc(expr: GeneralScalarExpression): Option[Expression] = {
+ expr.name match {
+ case "~" => convertUnaryExpr(expr, BitwiseNot)
+ case "&" => convertBinaryExpr(expr, BitwiseAnd)
+ case "|" => convertBinaryExpr(expr, BitwiseOr)
+ case "^" => convertBinaryExpr(expr, BitwiseXor)
+ case _ => None
+ }
+ }
+
+ private def convertUnaryExpr(
+ expr: GeneralScalarExpression,
+ catalystExprBuilder: Expression => Expression): Option[Expression] = {
+ expr.children match {
+ case Array(child) => toCatalyst(child).map(catalystExprBuilder)
+ case _ => None
+ }
+ }
+
+ private def convertBinaryExpr(
+ expr: GeneralScalarExpression,
+ catalystExprBuilder: (Expression, Expression) => Expression): Option[Expression] = {
+ expr.children match {
+ case Array(left, right) =>
+ for {
+ catalystLeft <- toCatalyst(left)
+ catalystRight <- toCatalyst(right)
+ } yield catalystExprBuilder(catalystLeft, catalystRight)
+ case _ => None
+ }
+ }
+
+ private def convertExpr(
+ expr: GeneralScalarExpression,
+ catalystExprBuilder: Seq[Expression] => Expression): Option[Expression] = {
+ val catalystChildren = expr.children.flatMap(toCatalyst).toImmutableArraySeq
+ if (expr.children.length == catalystChildren.length) {
+ Some(catalystExprBuilder(catalystChildren))
+ } else {
+ None
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 7789c23b50a48..015bd1e3e1428 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -277,14 +277,14 @@ private[aggregate] object CollectTopK {
@ExpressionDescription(
usage = """
_FUNC_(expr[, delimiter])[ WITHIN GROUP (ORDER BY key [ASC | DESC] [,...])] - Returns
- the concatenation of non-null input values, separated by the delimiter ordered by key.
- If all values are null, null is returned.
+ the concatenation of non-NULL input values, separated by the delimiter ordered by key.
+ If all values are NULL, NULL is returned.
""",
arguments = """
Arguments:
* expr - a string or binary expression to be concatenated.
* delimiter - an optional string or binary foldable expression used to separate the input values.
- If null, the concatenation will be performed without a delimiter. Default is null.
+ If NULL, the concatenation will be performed without a delimiter. Default is NULL.
* key - an optional expression for ordering the input values. Multiple keys can be specified.
If none are specified, the order of the rows in the result is non-deterministic.
""",
@@ -400,7 +400,7 @@ case class ListAgg(
)
)
} else if (delimiter.dataType == NullType) {
- // null is the default empty delimiter so type is not important
+ // Null is the default empty delimiter so type is not important
TypeCheckSuccess
} else {
TypeUtils.checkForSameTypeInputExpr(child.dataType :: delimiter.dataType :: Nil, prettyName)
@@ -451,7 +451,7 @@ case class ListAgg(
}
/**
- * @return ordering by (orderValue0, orderValue1, ...)
+ * @return Ordering by (orderValue0, orderValue1, ...)
* for InternalRow with format [childValue, orderValue0, orderValue1, ...]
*/
private[this] def bufferOrdering: Ordering[InternalRow] = {
@@ -477,7 +477,7 @@ case class ListAgg(
}
/**
- * @return delimiter value or default empty value if delimiter is null. Type respects [[dataType]]
+ * @return Delimiter value or default empty value if delimiter is null. Type respects [[dataType]]
*/
private[this] def getDelimiterValue: Either[UTF8String, Array[Byte]] = {
val delimiterValue = delimiter.eval()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index de74bb2f8cd21..2564d4eab9bd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1537,11 +1537,11 @@ object CodeGenerator extends Logging {
)
evaluator.setExtendedClass(classOf[GeneratedClass])
- logDebug({
+ logBasedOnLevel(SQLConf.get.codegenLogLevel) {
// Only add extra debugging info to byte code when we are going to print the source code.
evaluator.setDebuggingInformation(true, true, false)
- s"\n${CodeFormatter.format(code)}"
- })
+ log"\n${MDC(LogKeys.CODE, CodeFormatter.format(code))}"
+ }
val codeStats = try {
evaluator.cook("generated.java", code.body)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1fc8fe8f247b8..81484f8dd7da5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -457,6 +457,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with
case (name, expr) =>
val metadata = expr match {
case ne: NamedExpression => ne.metadata
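+ // A direct struct-field extraction keeps the extracted field's metadata.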
+ case gsf: GetStructField => gsf.metadata
case _ => Metadata.empty
}
StructField(name.toString, expr.dataType, expr.nullable, metadata)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 2013cd8d6e636..804c80bd68b49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.QueryContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, TreePattern}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -90,7 +91,7 @@ object ExtractValue {
}
}
-trait ExtractValue extends Expression {
+trait ExtractValue extends Expression with QueryErrorsBase {
override def nullIntolerant: Boolean = true
final override val nodePatterns: Seq[TreePattern] = Seq(EXTRACT_VALUE)
val child: Expression
@@ -314,6 +315,30 @@ case class GetArrayItem(
})
}
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (_: ArrayType, e2) if !e2.isInstanceOf[IntegralType] =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(1),
+ "requiredType" -> toSQLType(IntegralType),
+ "inputSql" -> toSQLExpr(right),
+ "inputType" -> toSQLType(right.dataType))
+ )
+ case (e1, _) if !e1.isInstanceOf[ArrayType] =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(0),
+ "requiredType" -> toSQLType(TypeCollection(ArrayType)),
+ "inputSql" -> toSQLExpr(left),
+ "inputType" -> toSQLType(left.dataType))
+ )
+ case _ => TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): GetArrayItem =
copy(child = newLeft, ordinal = newRight)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala
new file mode 100644
index 0000000000000..1a52723669358
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import java.util.UUID
+
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+import org.apache.spark.sql.types.{DataType, StringType}
+
+trait TableConstraint {
+
+ /** Returns the user-provided name of the constraint */
+ def userProvidedName: String
+
+ /** Returns the name of the table containing this constraint */
+ def tableName: String
+
+ /** Returns the user-provided characteristics of the constraint (e.g., ENFORCED, RELY) */
+ def userProvidedCharacteristic: ConstraintCharacteristic
+
+ /** Creates a new constraint with the user-provided name
+ *
+ * @param name Constraint name
+ * @return New TableConstraint instance
+ */
+ def withUserProvidedName(name: String): TableConstraint
+
+ /**
+ * Creates a new constraint with the given table name
+ *
+ * @param tableName Name of the table containing this constraint
+ * @return New TableConstraint instance
+ */
+ def withTableName(tableName: String): TableConstraint
+
+ /** Creates a new constraint with the user-provided characteristic
+ *
+ * @param c Constraint characteristic (ENFORCED, RELY)
+ * @return New TableConstraint instance
+ */
+ def withUserProvidedCharacteristic(c: ConstraintCharacteristic): TableConstraint
+
+ // Generate a constraint name based on the table name if the name is not specified
+ protected def generateName(tableName: String): String
+
+ /**
+ * Gets the constraint name. If no name is provided by the user (null or empty),
+ * generates a name based on the table name using generateName.
+ *
+ * @return The constraint name (either user-provided or generated)
+ */
+ final def name: String = {
+ if (userProvidedName == null || userProvidedName.isEmpty) {
+ generateName(tableName)
+ } else {
+ userProvidedName
+ }
+ }
+
+ // This method generates a random identifier that has a similar format to Git commit hashes,
+ // which provide a good balance between uniqueness and readability when used as constraint
+ // identifiers.
+ final protected def randomSuffix: String = {
+ UUID.randomUUID().toString.replace("-", "").take(7)
+ }
+
+ protected def failIfEnforced(c: ConstraintCharacteristic, constraintType: String): Unit = {
+ if (c.enforced.contains(true)) {
+ val origin = CurrentOrigin.get
+ throw new ParseException(
+ command = origin.sqlText,
+ start = origin,
+ errorClass = "UNSUPPORTED_CONSTRAINT_CHARACTERISTIC",
+ messageParameters = Map(
+ "characteristic" -> "ENFORCED",
+ "constraintType" -> constraintType)
+ )
+ }
+ }
+}
+
+case class ConstraintCharacteristic(enforced: Option[Boolean], rely: Option[Boolean])
+
+object ConstraintCharacteristic {
+ val empty: ConstraintCharacteristic = ConstraintCharacteristic(None, None)
+}
+
+// scalastyle:off line.size.limit
+case class CheckConstraint(
+ child: Expression,
+ condition: String,
+ override val userProvidedName: String = null,
+ override val tableName: String = null,
+ override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
+ extends UnaryExpression
+ with Unevaluable
+ with TableConstraint {
+// scalastyle:on line.size.limit
+
+ override protected def withNewChildInternal(newChild: Expression): Expression =
+ copy(child = newChild)
+
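+ // e.g. an unnamed CHECK constraint on table "orders" gets a generated name like
+ // "orders_chk_1a2b3c4".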
+ override protected def generateName(tableName: String): String = {
+ s"${tableName}_chk_$randomSuffix"
+ }
+
+ override def sql: String = s"CONSTRAINT $userProvidedName CHECK ($condition)"
+
+ override def dataType: DataType = StringType
+
+ override def withUserProvidedName(name: String): TableConstraint = copy(userProvidedName = name)
+
+ override def withTableName(tableName: String): TableConstraint = copy(tableName = tableName)
+
+ override def withUserProvidedCharacteristic(c: ConstraintCharacteristic): TableConstraint =
+ copy(userProvidedCharacteristic = c)
+}
+
+// scalastyle:off line.size.limit
+case class PrimaryKeyConstraint(
+ columns: Seq[String],
+ override val userProvidedName: String = null,
+ override val tableName: String = null,
+ override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
+ extends TableConstraint {
+// scalastyle:on line.size.limit
+
+ override protected def generateName(tableName: String): String = s"${tableName}_pk"
+
+ override def withUserProvidedName(name: String): TableConstraint = copy(userProvidedName = name)
+
+ override def withTableName(tableName: String): TableConstraint = copy(tableName = tableName)
+
+ override def withUserProvidedCharacteristic(c: ConstraintCharacteristic): TableConstraint = {
+ failIfEnforced(c, "PRIMARY KEY")
+ copy(userProvidedCharacteristic = c)
+ }
+}
+
+// scalastyle:off line.size.limit
+case class UniqueConstraint(
+ columns: Seq[String],
+ override val userProvidedName: String = null,
+ override val tableName: String = null,
+ override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
+ extends TableConstraint {
+// scalastyle:on line.size.limit
+
+ override protected def generateName(tableName: String): String = {
+ s"${tableName}_uniq_$randomSuffix"
+ }
+
+ override def withUserProvidedName(name: String): TableConstraint = copy(userProvidedName = name)
+
+ override def withTableName(tableName: String): TableConstraint = copy(tableName = tableName)
+
+ override def withUserProvidedCharacteristic(c: ConstraintCharacteristic): TableConstraint = {
+ failIfEnforced(c, "UNIQUE")
+ copy(userProvidedCharacteristic = c)
+ }
+}
+
+// scalastyle:off line.size.limit
+case class ForeignKeyConstraint(
+ childColumns: Seq[String] = Seq.empty,
+ parentTableId: Seq[String] = Seq.empty,
+ parentColumns: Seq[String] = Seq.empty,
+ override val userProvidedName: String = null,
+ override val tableName: String = null,
+ override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
+ extends TableConstraint {
+// scalastyle:on line.size.limit
+
+ override protected def generateName(tableName: String): String =
+ s"${tableName}_${parentTableId.last}_fk_$randomSuffix"
+
+ override def withUserProvidedName(name: String): TableConstraint = copy(userProvidedName = name)
+
+ override def withTableName(tableName: String): TableConstraint = copy(tableName = tableName)
+
+ override def withUserProvidedCharacteristic(c: ConstraintCharacteristic): TableConstraint = {
+ failIfEnforced(c, "FOREIGN KEY")
+ copy(userProvidedCharacteristic = c)
+ }
+}
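+
+// Illustrative sketch only (not part of any resolved plan): a caller such as the SQL parser can
+// assemble a constraint by chaining the `with*` setters above. Here `cond` stands for the already
+// parsed boolean expression behind the condition text:
+//
+//   CheckConstraint(child = cond, condition = "id > 0")
+//     .withTableName("orders")
+//     .withUserProvidedName("orders_positive_id")
+//     .withUserProvidedCharacteristic(ConstraintCharacteristic(enforced = Some(true), rely = Some(false)))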
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
index fd298b33450b3..4d6c862a4fbce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.expressions.csv
import com.univocity.parsers.csv.CsvParser
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{DataSourceOptions, InternalRow}
import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions, UnivocityParser}
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, PermissiveMode}
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{DataType, NullType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -66,6 +66,7 @@ case class CsvToStructsEvaluator(
if (mode != PermissiveMode && mode != FailFastMode) {
throw QueryCompilationErrors.parseModeUnsupportedError("from_csv", mode)
}
+ DataSourceOptions.validateSingleVariantColumn(parsedOptions.parameters, Some(nullableSchema))
ExprUtils.verifyColumnNameOfCorruptRecord(
nullableSchema,
parsedOptions.columnNameOfCorruptRecord)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 67d9aff947cfa..b4b81dade83fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -399,15 +400,6 @@ trait GetTimeField extends UnaryExpression
}
}
-@ExpressionDescription(
- usage = "_FUNC_(timestamp) - Returns the hour component of the string/timestamp.",
- examples = """
- Examples:
- > SELECT _FUNC_('2009-07-30 12:58:59');
- 12
- """,
- group = "datetime_funcs",
- since = "1.5.0")
case class Hour(child: Expression, timeZoneId: Option[String] = None) extends GetTimeField {
def this(child: Expression) = this(child, None)
override def withTimeZone(timeZoneId: String): Hour = copy(timeZoneId = Option(timeZoneId))
@@ -2104,11 +2096,13 @@ case class ParseToDate(
ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends RuntimeReplaceable with ImplicitCastInputTypes with TimeZoneAwareExpression {
- override lazy val replacement: Expression = format.map { f =>
- Cast(GetTimestamp(left, f, TimestampType, "try_to_date", timeZoneId, ansiEnabled), DateType,
- timeZoneId, EvalMode.fromBoolean(ansiEnabled))
- }.getOrElse(Cast(left, DateType, timeZoneId,
- EvalMode.fromBoolean(ansiEnabled))) // backwards compatibility
+ override lazy val replacement: Expression = withOrigin(origin) {
+ format.map { f =>
+ Cast(GetTimestamp(left, f, TimestampType, "try_to_date", timeZoneId, ansiEnabled), DateType,
+ timeZoneId, EvalMode.fromBoolean(ansiEnabled))
+ }.getOrElse(Cast(left, DateType, timeZoneId,
+ EvalMode.fromBoolean(ansiEnabled))) // backwards compatibility
+ }
def this(left: Expression, format: Expression) = {
this(left, Option(format))
@@ -2183,9 +2177,11 @@ case class ParseToTimestamp(
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends RuntimeReplaceable with ImplicitCastInputTypes with TimeZoneAwareExpression {
- override lazy val replacement: Expression = format.map { f =>
- GetTimestamp(left, f, dataType, "try_to_timestamp", timeZoneId, failOnError = failOnError)
- }.getOrElse(Cast(left, dataType, timeZoneId, ansiEnabled = failOnError))
+ override lazy val replacement: Expression = withOrigin(origin) {
+ format.map { f =>
+ GetTimestamp(left, f, dataType, "try_to_timestamp", timeZoneId, failOnError = failOnError)
+ }.getOrElse(Cast(left, dataType, timeZoneId, ansiEnabled = failOnError))
+ }
def this(left: Expression, format: Expression) = {
this(left, Option(format), SQLConf.get.timestampType)
@@ -3097,7 +3093,7 @@ object DatePartExpressionBuilder extends ExpressionBuilder {
// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(field FROM source) - Extracts a part of the date/timestamp or interval source.",
+ usage = "_FUNC_(field FROM source) - Extracts a part of the date or timestamp or time or interval source.",
arguments = """
Arguments:
* field - selects which part of the source should be extracted
@@ -3121,7 +3117,11 @@ object DatePartExpressionBuilder extends ExpressionBuilder {
- "HOUR", ("H", "HOURS", "HR", "HRS") - how many hours the `microseconds` contains
- "MINUTE", ("M", "MIN", "MINS", "MINUTES") - how many minutes left after taking hours from `microseconds`
- "SECOND", ("S", "SEC", "SECONDS", "SECS") - how many second with fractions left after taking hours and minutes from `microseconds`
- * source - a date/timestamp or interval column from where `field` should be extracted
+ - Supported string values of `field` for time (which consists of `hour`, `minute`, `second`) are (case insensitive):
+ - "HOUR", ("H", "HOURS", "HR", "HRS") - the hour field (0 - 23)
+ - "MINUTE", ("M", "MIN", "MINS", "MINUTES") - the minute field (0 - 59)
+ - "SECOND", ("S", "SEC", "SECONDS", "SECS") - the second field, including fractional parts, up to microsecond precision; the value is returned as DECIMAL(8, 6)
+ * source - a date, timestamp, time, or interval column from which `field` should be extracted
""",
examples = """
Examples:
@@ -3141,6 +3141,12 @@ object DatePartExpressionBuilder extends ExpressionBuilder {
11
> SELECT _FUNC_(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND);
55
+ > SELECT _FUNC_(HOUR FROM time '09:08:01.000001');
+ 9
+ > SELECT _FUNC_(MINUTE FROM time '09:08:01.000001');
+ 8
+ > SELECT _FUNC_(SECOND FROM time '09:08:01.000001');
+ 1.000001
""",
note = """
The _FUNC_ function is equivalent to `date_part(field, source)`.
@@ -3179,6 +3185,8 @@ object Extract {
source.dataType match {
case _: AnsiIntervalType | CalendarIntervalType =>
ExtractIntervalPart.parseExtractField(fieldStr.toString, source)
+ case _: TimeType =>
+ TimePart.parseExtractField(fieldStr.toString, source)
case _ =>
DatePart.parseExtractField(fieldStr.toString, source)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index ac493d19df1b5..7cb645e601d36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -256,6 +256,8 @@ case class Crc32(child: Expression)
* input with seed.
* - binary: use murmur3 to hash the bytes with seed.
* - string: get the bytes of string and hash it.
+ * - time: it stores the number of `microseconds` since midnight as a long value, use
+ * murmur3 to hash the long input with seed.
* - array: The `result` starts with seed, then use `result` as seed, recursively
* calculate hash value for each element, and assign the element hash
* value to `result`.
@@ -507,7 +509,7 @@ abstract class HashExpression[E] extends Expression {
case NullType => ""
case BooleanType => genHashBoolean(input, result)
case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result)
- case LongType => genHashLong(input, result)
+ case LongType | _: TimeType => genHashLong(input, result)
case TimestampType | TimestampNTZType => genHashTimestamp(input, result)
case FloatType => genHashFloat(input, result)
case DoubleType => genHashDouble(input, result)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 5a34b21703e52..e3ed2c4a0b0b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
import java.math.{BigDecimal => JavaBigDecimal}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffset}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period, ZoneOffset}
import java.util
import java.util.Objects
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.trees.TreePattern.{LITERAL, NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localTimeToMicros}
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros, periodToMonths, toDayTimeIntervalString, toYearMonthIntervalString}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -89,6 +89,7 @@ object Literal {
case l: LocalDateTime => Literal(DateTimeUtils.localDateTimeToMicros(l), TimestampNTZType)
case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
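+ // e.g. (illustrative) LocalTime.of(12, 30) becomes Literal(45000000000L, TimeType()),
+ // i.e. 12:30 expressed as microseconds since midnight.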
+ case lt: LocalTime => Literal(localTimeToMicros(lt), TimeType())
case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType())
case p: Period => Literal(periodToMonths(p), YearMonthIntervalType())
case a: Array[Byte] => Literal(a, BinaryType)
@@ -126,6 +127,7 @@ object Literal {
// java classes
case _ if clz == classOf[LocalDate] => DateType
case _ if clz == classOf[Date] => DateType
+ case _ if clz == classOf[LocalTime] => TimeType()
case _ if clz == classOf[Instant] => TimestampType
case _ if clz == classOf[Timestamp] => TimestampType
case _ if clz == classOf[LocalDateTime] => TimestampNTZType
@@ -198,6 +200,7 @@ object Literal {
case DateType => create(0, DateType)
case TimestampType => create(0L, TimestampType)
case TimestampNTZType => create(0L, TimestampNTZType)
+ case t: TimeType => create(0L, t)
case it: DayTimeIntervalType => create(0L, it)
case it: YearMonthIntervalType => create(0, it)
case CharType(length) =>
@@ -432,6 +435,8 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
dataType match {
case DateType =>
DateFormatter().format(value.asInstanceOf[Int])
+ case _: TimeType =>
+ new FractionTimeFormatter().format(value.asInstanceOf[Long])
case TimestampType =>
TimestampFormatter.getFractionFormatter(timeZoneId).format(value.asInstanceOf[Long])
case TimestampNTZType =>
@@ -478,7 +483,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
val jsonValue = (value, dataType) match {
case (null, _) => JNull
case (i: Int, DateType) => JString(toString)
- case (l: Long, TimestampType) => JString(toString)
+ case (l: Long, TimestampType | _: TimeType) => JString(toString)
case (other, _) => JString(other.toString)
}
("value" -> jsonValue) :: ("dataType" -> dataType.jsonValue) :: Nil
@@ -521,7 +526,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
}
case ByteType | ShortType =>
ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
- case TimestampType | TimestampNTZType | LongType | _: DayTimeIntervalType =>
+ case TimestampType | TimestampNTZType | LongType | _: DayTimeIntervalType | _: TimeType =>
toExprCode(s"${value}L")
case _ =>
val constRef = ctx.addReferenceObj("literal", value, javaType)
@@ -562,6 +567,8 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
case (v: Decimal, t: DecimalType) => s"${v}BD"
case (v: Int, DateType) =>
s"DATE '$toString'"
+ case (_: Long, _: TimeType) =>
+ s"TIME '$toString'"
case (v: Long, TimestampType) =>
s"TIMESTAMP '$toString'"
case (v: Long, TimestampNTZType) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index cd5aedb9bb891..e7d3701544c54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.TreePattern.{CURRENT_LIKE, TreePattern}
-import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator}
+import org.apache.spark.sql.catalyst.util.{toPrettySQL, MapData, RandomUUIDGenerator}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
import org.apache.spark.sql.internal.SQLConf
@@ -179,7 +179,7 @@ case class AssertTrue(left: Expression, right: Expression, replacement: Expressi
}
def this(left: Expression) = {
- this(left, Literal(s"'${left.simpleString(SQLConf.get.maxToStringFields)}' is not true!"))
+ this(left, Literal(s"'${toPrettySQL(left)}' is not true!"))
}
override def parameters: Seq[Expression] = Seq(left, right)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index d8d81a9cc12f8..c31c72bc11488 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,7 +23,7 @@ import scala.collection.immutable.TreeSet
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedPlanId}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
import org.apache.spark.sql.catalyst.expressions.Cast._
@@ -419,6 +419,13 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
copy(values = newChildren.dropRight(1), query = newChildren.last.asInstanceOf[ListQuery])
}
+case class UnresolvedInSubqueryPlanId(values: Seq[Expression], planId: Long)
+ extends UnresolvedPlanId {
+
+ override def withPlan(plan: LogicalPlan): Expression = {
+ InSubquery(values, ListQuery(plan))
+ }
+}
/**
* Evaluates to `true` if `list` contains `value`.
@@ -1010,6 +1017,11 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
// finitely enumerable. The allowable types are checked below by checkInputDataTypes.
override def inputType: AbstractDataType = AnyDataType
+ // For value comparison, struct field names and nullability do not matter.
+ protected override def sameType(left: DataType, right: DataType): Boolean = {
+ DataType.equalsStructurally(left, right, ignoreNullability = true)
+ }
+
final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_COMPARISON)
override lazy val canonicalized: Expression = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/timeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/timeExpressions.scala
new file mode 100644
index 0000000000000..0f9405f43634c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/timeExpressions.scala
@@ -0,0 +1,550 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.time.DateTimeException
+import java.util.Locale
+
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType, toSQLValue}
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CURRENT_LIKE, TreePattern}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimeFormatter}
+import org.apache.spark.sql.catalyst.util.TypeUtils.ordinalNumber
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.internal.types.StringTypeWithCollation
+import org.apache.spark.sql.types.{AbstractDataType, DataType, DecimalType, IntegerType, ObjectType, TimeType, TypeCollection}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Parses a column to a time based on the given format.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(str[, format]) - Parses the `str` expression with the `format` expression to a time.
+ If `format` is malformed or its application does not result in a well formed time, the function
+ raises an error. By default, it follows casting rules to a time if the `format` is omitted.
+ """,
+ arguments = """
+ Arguments:
+ * str - A string to be parsed to time.
+ * format - Time format pattern to follow. See Datetime Patterns for valid
+ time format patterns.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_('00:12:00');
+ 00:12:00
+ > SELECT _FUNC_('12.10.05', 'HH.mm.ss');
+ 12:10:05
+ """,
+ group = "datetime_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ToTime(str: Expression, format: Option[Expression])
+ extends RuntimeReplaceable with ExpectsInputTypes {
+
+ def this(str: Expression, format: Expression) = this(str, Option(format))
+ def this(str: Expression) = this(str, None)
+
+ private def invokeParser(
+ fmt: Option[String] = None,
+ arguments: Seq[Expression] = children): Expression = {
+ Invoke(
+ targetObject = Literal.create(ToTimeParser(fmt), ObjectType(classOf[ToTimeParser])),
+ functionName = "parse",
+ dataType = TimeType(),
+ arguments = arguments,
+ methodInputTypes = arguments.map(_.dataType))
+ }
+
+ override lazy val replacement: Expression = format match {
+ case None => invokeParser()
+ case Some(expr) if expr.foldable =>
+ Option(expr.eval())
+ .map(f => invokeParser(Some(f.toString), Seq(str)))
+ .getOrElse(Literal(null, expr.dataType))
+ case _ => invokeParser()
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ Seq(StringTypeWithCollation(supportsTrimCollation = true)) ++
+ format.map(_ => StringTypeWithCollation(supportsTrimCollation = true))
+ }
+
+ override def prettyName: String = "to_time"
+
+ override def children: Seq[Expression] = str +: format.toSeq
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ if (format.isDefined) {
+ copy(str = newChildren.head, format = Some(newChildren.last))
+ } else {
+ copy(str = newChildren.head)
+ }
+ }
+}
+
+case class ToTimeParser(fmt: Option[String]) {
+ private lazy val formatter = TimeFormatter(fmt, isParsing = true)
+
+ def this() = this(None)
+
+ private def withErrorCondition(input: => UTF8String, fmt: => Option[String])
+ (f: => Long): Long = {
+ try f
+ catch {
+ case e: DateTimeException =>
+ throw QueryExecutionErrors.timeParseError(input.toString, fmt, e)
+ }
+ }
+
+ def parse(s: UTF8String): Long = withErrorCondition(s, fmt)(formatter.parse(s.toString))
+
+ def parse(s: UTF8String, fmt: UTF8String): Long = {
+ val format = fmt.toString
+ withErrorCondition(s, Some(format)) {
+ TimeFormatter(format, isParsing = true).parse(s.toString)
+ }
+ }
+}
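+
+// Rough illustration of the parser above (values assume the "HH.mm.ss" pattern used in the example):
+//   ToTimeParser(Some("HH.mm.ss")).parse(UTF8String.fromString("12.10.05"))
+// returns 43805000000L, i.e. 12:10:05 as microseconds since midnight; malformed input is reported
+// via QueryExecutionErrors.timeParseError instead of surfacing a raw DateTimeException.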
+
+object TimePart {
+
+ def parseExtractField(extractField: String, source: Expression): Expression =
+ extractField.toUpperCase(Locale.ROOT) match {
+ case "HOUR" | "H" | "HOURS" | "HR" | "HRS" => HoursOfTime(source)
+ case "MINUTE" | "M" | "MIN" | "MINS" | "MINUTES" => MinutesOfTime(source)
+ case "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" => SecondsOfTimeWithFraction(source)
+ case _ =>
+ throw QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(
+ extractField,
+ source)
+ }
+}
+
+/**
+ * Parses a column to a time based on the supplied format.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(str[, format]) - Parses the `str` expression with the `format` expression to a time.
+ If `format` is malformed or its application does not result in a well formed time, the function
+ returns NULL. By default, it follows casting rules to a time if the `format` is omitted.
+ """,
+ arguments = """
+ Arguments:
+ * str - A string to be parsed to time.
+ * format - Time format pattern to follow. See Datetime Patterns for valid
+ time format patterns.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_('00:12:00.001');
+ 00:12:00.001
+ > SELECT _FUNC_('12.10.05.999999', 'HH.mm.ss.SSSSSS');
+ 12:10:05.999999
+ > SELECT _FUNC_('foo', 'HH:mm:ss');
+ NULL
+ """,
+ group = "datetime_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+object TryToTimeExpressionBuilder extends ExpressionBuilder {
+ override def build(funcName: String, expressions: Seq[Expression]): Expression = {
+ val numArgs = expressions.length
+ if (numArgs == 1 || numArgs == 2) {
+ TryEval(ToTime(expressions.head, expressions.drop(1).lastOption))
+ } else {
+ throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(1, 2), numArgs)
+ }
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(time_expr) - Returns the minute component of the given time.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(TIME'23:59:59.999999');
+ 59
+ """,
+ since = "4.1.0",
+ group = "datetime_funcs")
+// scalastyle:on line.size.limit
+case class MinutesOfTime(child: Expression)
+ extends RuntimeReplaceable
+ with ExpectsInputTypes {
+
+ override def replacement: Expression = StaticInvoke(
+ classOf[DateTimeUtils.type],
+ IntegerType,
+ "getMinutesOfTime",
+ Seq(child),
+ Seq(child.dataType)
+ )
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(TimeType.MIN_PRECISION to TimeType.MAX_PRECISION map TimeType.apply: _*))
+
+ override def children: Seq[Expression] = Seq(child)
+
+ override def prettyName: String = "minute"
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ copy(child = newChildren.head)
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr) - Returns the minute component of the given expression.
+
+ If `expr` is a TIMESTAMP or a string that can be cast to timestamp,
+ it returns the minute of that timestamp.
+ If `expr` is a TIME type (since 4.1.0), it returns the minute of the time-of-day.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_('2009-07-30 12:58:59');
+ 58
+ > SELECT _FUNC_(TIME'23:59:59.999999');
+ 59
+ """,
+ since = "1.5.0",
+ group = "datetime_funcs")
+// scalastyle:on line.size.limit
+object MinuteExpressionBuilder extends ExpressionBuilder {
+ override def build(name: String, expressions: Seq[Expression]): Expression = {
+ if (expressions.isEmpty) {
+ throw QueryCompilationErrors.wrongNumArgsError(name, Seq("> 0"), expressions.length)
+ } else {
+ val child = expressions.head
+ child.dataType match {
+ case _: TimeType =>
+ MinutesOfTime(child)
+ case _ =>
+ Minute(child)
+ }
+ }
+ }
+}
+
+case class HoursOfTime(child: Expression)
+ extends RuntimeReplaceable
+ with ExpectsInputTypes {
+
+ override def replacement: Expression = StaticInvoke(
+ classOf[DateTimeUtils.type],
+ IntegerType,
+ "getHoursOfTime",
+ Seq(child),
+ Seq(child.dataType)
+ )
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(TimeType.MIN_PRECISION to TimeType.MAX_PRECISION map TimeType.apply: _*))
+
+ override def children: Seq[Expression] = Seq(child)
+
+ override def prettyName: String = "hour"
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ copy(child = newChildren.head)
+ }
+}
+
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr) - Returns the hour component of the given expression.
+
+ If `expr` is a TIMESTAMP or a string that can be cast to timestamp,
+ it returns the hour of that timestamp.
+ If `expr` is a TIME type (since 4.1.0), it returns the hour of the time-of-day.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_('2018-02-14 12:58:59');
+ 12
+ > SELECT _FUNC_(TIME'13:59:59.999999');
+ 13
+ """,
+ since = "1.5.0",
+ group = "datetime_funcs")
+object HourExpressionBuilder extends ExpressionBuilder {
+ override def build(name: String, expressions: Seq[Expression]): Expression = {
+ if (expressions.isEmpty) {
+ throw QueryCompilationErrors.wrongNumArgsError(name, Seq("> 0"), expressions.length)
+ } else {
+ val child = expressions.head
+ child.dataType match {
+ case _: TimeType =>
+ HoursOfTime(child)
+ case _ =>
+ Hour(child)
+ }
+ }
+ }
+}
+
+case class SecondsOfTimeWithFraction(child: Expression)
+ extends RuntimeReplaceable
+ with ExpectsInputTypes {
+
+ override def replacement: Expression = {
+ StaticInvoke(
+ classOf[DateTimeUtils.type],
+ DecimalType(8, 6),
+ "getSecondsOfTimeWithFraction",
+ Seq(child, Literal(precision)),
+ Seq(child.dataType, IntegerType))
+ }
+ private val precision: Int = child.dataType.asInstanceOf[TimeType].precision
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TimeType(precision))
+
+ override def children: Seq[Expression] = Seq(child)
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ copy(child = newChildren.head)
+ }
+}
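+
+// For example (illustrative), extracting SECOND from TIME'09:08:01.000001' is planned through
+// SecondsOfTimeWithFraction and yields 1.000001 as a DECIMAL(8, 6) value.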
+
+case class SecondsOfTime(child: Expression)
+ extends RuntimeReplaceable
+ with ExpectsInputTypes {
+
+ override def replacement: Expression = StaticInvoke(
+ classOf[DateTimeUtils.type],
+ IntegerType,
+ "getSecondsOfTime",
+ Seq(child),
+ Seq(child.dataType)
+ )
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(TimeType.MIN_PRECISION to TimeType.MAX_PRECISION map TimeType.apply: _*))
+
+ override def children: Seq[Expression] = Seq(child)
+
+ override def prettyName: String = "second"
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ copy(child = newChildren.head)
+ }
+}
+
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr) - Returns the second component of the given expression.
+
+ If `expr` is a TIMESTAMP or a string that can be cast to timestamp,
+ it returns the second of that timestamp.
+ If `expr` is a TIME type (since 4.1.0), it returns the second of the time-of-day.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_('2018-02-14 12:58:59');
+ 59
+ > SELECT _FUNC_(TIME'13:25:59.999999');
+ 59
+ """,
+ since = "1.5.0",
+ group = "datetime_funcs")
+object SecondExpressionBuilder extends ExpressionBuilder {
+ override def build(name: String, expressions: Seq[Expression]): Expression = {
+ if (expressions.isEmpty) {
+ throw QueryCompilationErrors.wrongNumArgsError(name, Seq("> 0"), expressions.length)
+ } else {
+ val child = expressions.head
+ child.dataType match {
+ case _: TimeType =>
+ SecondsOfTime(child)
+ case _ =>
+ Second(child)
+ }
+ }
+ }
+}
+
+/**
+ * Returns the current time at the start of query evaluation.
+ * There is no code generation since this expression should get constant folded by the optimizer.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_([precision]) - Returns the current time at the start of query evaluation.
+ All calls of current_time within the same query return the same value.
+
+ _FUNC_ - Returns the current time at the start of query evaluation.
+ """,
+ arguments = """
+ Arguments:
+ * precision - An optional integer literal in the range [0..6], indicating how many
+ fractional digits of seconds to include. If omitted, the default is 6.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_();
+ 15:49:11.914120
+ > SELECT _FUNC_;
+ 15:49:11.914120
+ > SELECT _FUNC_(0);
+ 15:49:11
+ > SELECT _FUNC_(3);
+ 15:49:11.914
+ > SELECT _FUNC_(1+1);
+ 15:49:11.91
+ """,
+ group = "datetime_funcs",
+ since = "4.1.0"
+)
+case class CurrentTime(child: Expression = Literal(TimeType.MICROS_PRECISION))
+ extends UnaryExpression with FoldableUnevaluable with ImplicitCastInputTypes {
+
+ def this() = {
+ this(Literal(TimeType.MICROS_PRECISION))
+ }
+
+ final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
+
+ override def nullable: Boolean = false
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ // Check foldability
+ if (!child.foldable) {
+ return DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> toSQLId("precision"),
+ "inputType" -> toSQLType(child.dataType),
+ "inputExpr" -> toSQLExpr(child)
+ )
+ )
+ }
+
+ // Evaluate
+ val precisionValue = child.eval()
+ if (precisionValue == null) {
+ return DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_NULL",
+ messageParameters = Map("exprName" -> "precision"))
+ }
+
+ // Check numeric range
+ precisionValue match {
+ case n: Number =>
+ val p = n.intValue()
+ if (p < TimeType.MIN_PRECISION || p > TimeType.MICROS_PRECISION) {
+ return DataTypeMismatch(
+ errorSubClass = "VALUE_OUT_OF_RANGE",
+ messageParameters = Map(
+ "exprName" -> toSQLId("precision"),
+ "valueRange" -> s"[${TimeType.MIN_PRECISION}, ${TimeType.MICROS_PRECISION}]",
+ "currentValue" -> toSQLValue(p, IntegerType)
+ )
+ )
+ }
+ case _ =>
+ return DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(0),
+ "requiredType" -> toSQLType(IntegerType),
+ "inputSql" -> toSQLExpr(child),
+ "inputType" -> toSQLType(child.dataType))
+ )
+ }
+ TypeCheckSuccess
+ }
+
+ // Because checkInputDataTypes ensures the argument is foldable & valid,
+ // we can directly evaluate here.
+ lazy val precision: Int = child.eval().asInstanceOf[Number].intValue()
+
+ override def dataType: DataType = TimeType(precision)
+
+ override def prettyName: String = "current_time"
+
+ override protected def withNewChildInternal(newChild: Expression): Expression = {
+ copy(child = newChild)
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
+}
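+
+// Note (illustrative): CurrentTime is foldable but unevaluable. `SELECT current_time(3)` resolves to
+// CurrentTime(Literal(3)) with data type TimeType(3); the ComputeCurrentTime optimizer rule then
+// replaces every occurrence with the same TIME literal, so all calls in a query agree.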
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(hour, minute, second) - Create time from hour, minute and second fields. For invalid inputs it will throw an error.",
+ arguments = """
+ Arguments:
+ * hour - the hour to represent, from 0 to 23
+ * minute - the minute to represent, from 0 to 59
+ * second - the second to represent, from 0 to 59.999999
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(6, 30, 45.887);
+ 06:30:45.887
+ > SELECT _FUNC_(NULL, 30, 0);
+ NULL
+ """,
+ group = "datetime_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class MakeTime(
+ hours: Expression,
+ minutes: Expression,
+ secsAndMicros: Expression)
+ extends RuntimeReplaceable
+ with ImplicitCastInputTypes
+ with ExpectsInputTypes {
+
+ // Accept the seconds argument as DecimalType to avoid losing the microsecond precision carried in
+ // its fractional part. If the argument is an IntegerType, it can safely be cast to decimal because
+ // we use DecimalType(16, 6), which is wider than DecimalType(10, 0).
+ override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, DecimalType(16, 6))
+ override def children: Seq[Expression] = Seq(hours, minutes, secsAndMicros)
+ override def prettyName: String = "make_time"
+
+ override def replacement: Expression = StaticInvoke(
+ classOf[DateTimeUtils.type],
+ TimeType(TimeType.MICROS_PRECISION),
+ "timeToMicros",
+ children,
+ inputTypes
+ )
+
+ override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): MakeTime =
+ copy(hours = newChildren(0), minutes = newChildren(1), secsAndMicros = newChildren(2))
+}
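+
+// Rough sketch of the rewrite above (not an exhaustive contract): make_time(6, 30, 45.887) is
+// replaced by a StaticInvoke of DateTimeUtils.timeToMicros over (6, 30, 45.887 as DECIMAL(16, 6)),
+// producing the TIME value 06:30:45.887 as microseconds since midnight.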
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index fcd760561f909..5831a29c00a19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -203,11 +203,13 @@ object VariantPathParser extends RegexParsers {
ArrayExtraction(index.toInt)
}
+ override def skipWhitespace: Boolean = false
+
// Parse key segment like `.name`, `['name']`, or `["name"]`.
private def key: Parser[VariantPathSegment] =
for {
- key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
- "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+ key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^']*".r <~ "']" |
+ "[\"" ~> """[^"]*""".r <~ "\"]"
} yield {
ObjectExtraction(key)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
index 89d7b8d9421a7..438110d7acc4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
@@ -17,10 +17,15 @@
package org.apache.spark.sql.catalyst.expressions.xml
-import org.apache.spark.sql.catalyst.util.GenericArrayData
-import org.apache.spark.sql.catalyst.xml.XmlInferSchema
+import java.io.CharArrayWriter
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils}
+import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, GenericArrayData, PermissiveMode}
+import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions}
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
object XmlExpressionEvalUtils {
@@ -119,3 +124,88 @@ case class XPathListEvaluator(path: UTF8String) extends XPathEvaluator {
}
}
}
+
+case class XmlToStructsEvaluator(
+ options: Map[String, String],
+ nullableSchema: DataType,
+ nameOfCorruptRecord: String,
+ timeZoneId: Option[String],
+ child: Expression
+) {
+ @transient lazy val parsedOptions = new XmlOptions(options, timeZoneId.get, nameOfCorruptRecord)
+
+ // This converts parsed rows to the desired output by the given schema.
+ @transient
+ private lazy val converter =
+ (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null
+
+ // Parser that parses XML strings into internal rows
+ @transient
+ private lazy val parser = {
+ val mode = parsedOptions.parseMode
+ if (mode != PermissiveMode && mode != FailFastMode) {
+ throw QueryCompilationErrors.parseModeUnsupportedError("from_xml", mode)
+ }
+
+ // The parser is only used when the input schema is StructType
+ val schema = nullableSchema.asInstanceOf[StructType]
+ ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+ val rawParser = new StaxXmlParser(schema, parsedOptions)
+
+ val xsdSchema = Option(parsedOptions.rowValidationXSDPath).map(ValidatorUtil.getSchema)
+
+ new FailureSafeParser[String](
+ input => rawParser.doParseColumn(input, mode, xsdSchema),
+ mode,
+ schema,
+ parsedOptions.columnNameOfCorruptRecord)
+ }
+
+ final def evaluate(xml: UTF8String): Any = {
+ if (xml == null) return null
+ nullableSchema match {
+ case _: VariantType => StaxXmlParser.parseVariant(xml.toString, parsedOptions)
+ case _: StructType => converter(parser.parse(xml.toString))
+ }
+ }
+}
+
+case class StructsToXmlEvaluator(
+ options: Map[String, String],
+ inputSchema: DataType,
+ timeZoneId: Option[String]) {
+
+ @transient
+ lazy val writer = new CharArrayWriter()
+
+ @transient
+ lazy val gen =
+ new StaxXmlGenerator(inputSchema, writer, new XmlOptions(options, timeZoneId.get), false)
+
+ // This converts rows to the XML output according to the given schema.
+ @transient
+ lazy val converter: Any => UTF8String = {
+ def getAndReset(): UTF8String = {
+ gen.flush()
+ val xmlString = writer.toString
+ writer.reset()
+ UTF8String.fromString(xmlString)
+ }
+
+ inputSchema match {
+ case _: StructType =>
+ (row: Any) =>
+ gen.write(row.asInstanceOf[InternalRow])
+ getAndReset()
+ case _: VariantType =>
+ (v: Any) =>
+ gen.write(v.asInstanceOf[VariantVal])
+ getAndReset()
+ }
+ }
+
+ final def evaluate(value: Any): Any = {
+ if (value == null) return null
+ converter(value)
+ }
+}
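+
+// Rough usage sketch (illustrative; assumes a struct schema and an explicit session time zone):
+//   val evaluator = StructsToXmlEvaluator(
+//     options = Map.empty,
+//     inputSchema = StructType(Seq(StructField("a", IntegerType))),
+//     timeZoneId = Some("UTC"))
+//   evaluator.evaluate(InternalRow(1))  // an XML string such as <ROW><a>1</a></ROW>;
+//                                       // the exact shape depends on XmlOptions defaults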
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
index 25a054f79c368..b6c495bca5604 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
@@ -16,17 +16,14 @@
*/
package org.apache.spark.sql.catalyst.expressions
-import java.io.CharArrayWriter
-
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-import org.apache.spark.sql.catalyst.expressions.xml.XmlExpressionEvalUtils
-import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, FailureSafeParser, PermissiveMode}
+import org.apache.spark.sql.catalyst.expressions.xml.{StructsToXmlEvaluator, XmlExpressionEvalUtils, XmlToStructsEvaluator}
+import org.apache.spark.sql.catalyst.util.DropMalformedMode
import org.apache.spark.sql.catalyst.util.TypeUtils._
-import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions}
+import org.apache.spark.sql.catalyst.xml.{XmlInferSchema, XmlOptions}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
@@ -54,7 +51,7 @@ import org.apache.spark.unsafe.types.UTF8String
since = "4.0.0")
// scalastyle:on line.size.limit
case class XmlToStructs(
- schema: StructType,
+ schema: DataType,
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
@@ -65,7 +62,7 @@ case class XmlToStructs(
def this(child: Expression, schema: Expression, options: Map[String, String]) =
this(
- schema = ExprUtils.evalSchemaExpr(schema),
+ schema = ExprUtils.evalTypeExpr(schema),
options = options,
child = child,
timeZoneId = None)
@@ -81,45 +78,34 @@ case class XmlToStructs(
def this(child: Expression, schema: Expression, options: Expression) =
this(
- schema = ExprUtils.evalSchemaExpr(schema),
+ schema = ExprUtils.evalTypeExpr(schema),
options = ExprUtils.convertToMapData(options),
child = child,
timeZoneId = None)
- // This converts parsed rows to the desired output by the given schema.
+ override def checkInputDataTypes(): TypeCheckResult = nullableSchema match {
+ case _: StructType | _: VariantType =>
+ val checkResult = ExprUtils.checkXmlSchema(nullableSchema)
+ if (checkResult.isFailure) checkResult else super.checkInputDataTypes()
+ case _ =>
+ DataTypeMismatch(
+ errorSubClass = "INVALID_XML_SCHEMA",
+ messageParameters = Map("schema" -> toSQLType(nullableSchema)))
+ }
+
@transient
- private lazy val converter =
- (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null
+ private lazy val evaluator: XmlToStructsEvaluator =
+ XmlToStructsEvaluator(options, nullableSchema, nameOfCorruptRecord, timeZoneId, child)
private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
- @transient
- private lazy val parser = {
- val parsedOptions = new XmlOptions(options, timeZoneId.get, nameOfCorruptRecord)
- val mode = parsedOptions.parseMode
- if (mode != PermissiveMode && mode != FailFastMode) {
- throw QueryCompilationErrors.parseModeUnsupportedError("from_xml", mode)
- }
- ExprUtils.verifyColumnNameOfCorruptRecord(
- nullableSchema, parsedOptions.columnNameOfCorruptRecord)
- val rawParser = new StaxXmlParser(schema, parsedOptions)
- val xsdSchema = Option(parsedOptions.rowValidationXSDPath).map(ValidatorUtil.getSchema)
-
- new FailureSafeParser[String](
- input => rawParser.doParseColumn(input, mode, xsdSchema),
- mode,
- nullableSchema,
- parsedOptions.columnNameOfCorruptRecord)
- }
-
override def dataType: DataType = nullableSchema
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}
- override def nullSafeEval(xml: Any): Any =
- converter(parser.parse(xml.asInstanceOf[UTF8String].toString))
+ override def nullSafeEval(xml: Any): Any = evaluator.evaluate(xml.asInstanceOf[UTF8String])
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val expr = ctx.addReferenceObj("this", this)
@@ -258,6 +244,7 @@ case class StructsToXml(
override def checkInputDataTypes(): TypeCheckResult = {
child.dataType match {
case _: StructType => TypeCheckSuccess
+ case _: VariantType => TypeCheckSuccess
case _ => DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
@@ -271,33 +258,12 @@ case class StructsToXml(
}
@transient
- lazy val writer = new CharArrayWriter()
-
- @transient
- lazy val inputSchema: StructType = child.dataType.asInstanceOf[StructType]
-
- @transient
- lazy val gen = new StaxXmlGenerator(
- inputSchema, writer, new XmlOptions(options, timeZoneId.get), false)
-
- // This converts rows to the XML output according to the given schema.
- @transient
- lazy val converter: Any => UTF8String = {
- def getAndReset(): UTF8String = {
- gen.flush()
- val xmlString = writer.toString
- writer.reset()
- UTF8String.fromString(xmlString)
- }
- (row: Any) =>
- gen.write(row.asInstanceOf[InternalRow])
- getAndReset()
- }
+ private lazy val evaluator = StructsToXmlEvaluator(options, child.dataType, timeZoneId)
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
- override def nullSafeEval(value: Any): Any = converter(value)
+ override def nullSafeEval(value: Any): Any = evaluator.evaluate(value)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val expr = ctx.addReferenceObj("this", this)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 0e1dfdf366a89..6ba152d309846 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -91,7 +91,7 @@ class JSONOptions(
val parseMode: ParseMode =
parameters.get(MODE).map(ParseMode.fromString).getOrElse(PermissiveMode)
val columnNameOfCorruptRecord =
- parameters.getOrElse(COLUMN_NAME_OF_CORRUPTED_RECORD, defaultColumnNameOfCorruptRecord)
+ parameters.getOrElse(COLUMN_NAME_OF_CORRUPT_RECORD, defaultColumnNameOfCorruptRecord)
// Whether to ignore column of all null values or empty array/struct during schema inference
val dropFieldIfAllNull = parameters.get(DROP_FIELD_IF_ALL_NULL).map(_.toBoolean).getOrElse(false)
@@ -284,10 +284,10 @@ object JSONOptions extends DataSourceOptions {
val LINE_SEP = newOption("lineSep")
val PRETTY = newOption("pretty")
val INFER_TIMESTAMP = newOption("inferTimestamp")
- val COLUMN_NAME_OF_CORRUPTED_RECORD = newOption("columnNameOfCorruptRecord")
+ val COLUMN_NAME_OF_CORRUPT_RECORD = newOption(DataSourceOptions.COLUMN_NAME_OF_CORRUPT_RECORD)
val TIME_ZONE = newOption("timeZone")
val WRITE_NON_ASCII_CHARACTER_AS_CODEPOINT = newOption("writeNonAsciiCharacterAsCodePoint")
- val SINGLE_VARIANT_COLUMN = newOption("singleVariantColumn")
+ val SINGLE_VARIANT_COLUMN = newOption(DataSourceOptions.SINGLE_VARIANT_COLUMN)
val USE_UNSAFE_ROW = newOption("useUnsafeRow")
// Options with alternative
val ENCODING = "encoding"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
index 2d1e71a63a8ce..eb95bef537391 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
@@ -366,8 +366,10 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
newPlanSupportsHashAggregate && cachedPlanSupportsHashAggregate ||
newPlanSupportsHashAggregate == cachedPlanSupportsHashAggregate && {
val Seq(newPlanSupportsObjectHashAggregate, cachedPlanSupportsObjectHashAggregate) =
- aggregateExpressionsSeq.map(aggregateExpressions =>
- Aggregate.supportsObjectHashAggregate(aggregateExpressions))
+ aggregateExpressionsSeq.zip(groupByExpressionSeq).map {
+ case (aggregateExpressions, groupByExpressions) =>
+ Aggregate.supportsObjectHashAggregate(aggregateExpressions, groupByExpressions)
+ }
newPlanSupportsObjectHashAggregate && cachedPlanSupportsObjectHashAggregate ||
newPlanSupportsObjectHashAggregate == cachedPlanSupportsObjectHashAggregate
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9d269f37e58b9..7b437c302b145 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -100,7 +100,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
Seq(
// Operator push down
PushProjectionThroughUnion,
- PushProjectionThroughLimit,
+ PushProjectionThroughLimitAndOffset,
ReorderJoin,
EliminateOuterJoin,
PushDownPredicates,
@@ -671,8 +671,10 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
val subQueryAttributes = if (conf.getConf(SQLConf
.EXCLUDE_SUBQUERY_EXP_REFS_FROM_REMOVE_REDUNDANT_ALIASES)) {
// Collect the references for all the subquery expressions in the plan.
- AttributeSet.fromAttributeSets(plan.expressions.collect {
- case e: SubqueryExpression => e.references
+ AttributeSet.fromAttributeSets(plan.expressions.flatMap { e =>
+ e.collect {
+ case s: SubqueryExpression => s.references
+ }
})
} else {
AttributeSet.empty
@@ -848,6 +850,13 @@ object LimitPushDown extends Rule[LogicalPlan] {
case LocalLimit(exp, u: Union) =>
LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _))))
+ // If a limit node is present, propagate it down to UnionLoop so that it is later
+ // propagated to UnionLoopExec.
+ case LocalLimit(IntegerLiteral(limit), p @ Project(_, ul: UnionLoop)) =>
+ p.copy(child = ul.copy(limit = Some(limit)))
+ case LocalLimit(IntegerLiteral(limit), ul: UnionLoop) =>
+ ul.copy(limit = Some(limit))
+
// Add extra limits below JOIN:
// 1. For LEFT OUTER and RIGHT OUTER JOIN, we push limits to the left and right sides
// respectively if join condition is not empty.
@@ -1032,6 +1041,10 @@ object ColumnPruning extends Rule[LogicalPlan] {
p
}
+ // TODO: Pruning `UnionLoop`s needs to take into account both the outer `Project` and the inner
+ // `UnionLoopRef` nodes.
+ case p @ Project(_, _: UnionLoop) => p
+
// Prune unnecessary window expressions
case p @ Project(_, w: Window) if !w.windowOutputSet.subsetOf(p.references) =>
val windowExprs = w.windowExpressions.filter(p.references.contains)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimit.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffset.scala
similarity index 75%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimit.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffset.scala
index 6280cc5e42c9f..e329251c36083 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimit.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffset.scala
@@ -17,16 +17,16 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LocalLimit, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LocalLimit, LogicalPlan, Offset, Project}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LIMIT, PROJECT}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIMIT, OFFSET, PROJECT}
/**
* Pushes Project operator through Limit operator.
*/
-object PushProjectionThroughLimit extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
- _.containsAllPatterns(PROJECT, LIMIT)) {
+object PushProjectionThroughLimitAndOffset extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(treeBits =>
+ treeBits.containsPattern(PROJECT) && treeBits.containsAnyPattern(OFFSET, LIMIT)) {
case p @ Project(projectList, limit @ LocalLimit(_, child))
if projectList.forall(_.deterministic) =>
@@ -35,5 +35,9 @@ object PushProjectionThroughLimit extends Rule[LogicalPlan] {
case p @ Project(projectList, g @ GlobalLimit(_, limit @ LocalLimit(_, child)))
if projectList.forall(_.deterministic) =>
g.copy(child = limit.copy(child = p.copy(projectList, child)))
+
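+ // For example (illustrative): Project(a, Offset(10, r)) is rewritten to Offset(10, Project(a, r)),
+ // so the deterministic projection is evaluated below the offset.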
+ case p @ Project(projectList, offset @ Offset(_, child))
+ if projectList.forall(_.deterministic) =>
+ offset.copy(child = p.copy(projectList, child))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index e867953bcf282..b3236bbfa3755 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, TreeNodeTag}
import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1091,6 +1092,8 @@ object FoldablePropagation extends Rule[LogicalPlan] {
object SimplifyCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
_.containsPattern(CAST), ruleId) {
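+ // Keep a seemingly redundant to-string cast over a char/varchar column (marked by the
+ // CHAR_VARCHAR_TYPE_STRING_METADATA_KEY metadata); removing it would let the attribute's
+ // char/varchar semantics resurface.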
+ case c @ Cast(e: NamedExpression, StringType, _, _)
+ if e.dataType == StringType && e.metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY) => c
case Cast(e, dataType, _, _) if e.dataType == dataType => e
case c @ Cast(Cast(e, dt1: NumericType, _, _), dt2: NumericType, _, _)
if isWiderCast(e.dataType, dt1) && isWiderCast(dt1, dt2) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index 0fbfce5962c73..21e09f2e56d19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.trees.TreePatternBits
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros}
+import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils.{instantToMicrosOfDay, truncateTimeMicrosToPrecision}
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.types._
@@ -113,6 +114,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
val instant = Instant.now()
val currentTimestampMicros = instantToMicros(instant)
val currentTime = Literal.create(currentTimestampMicros, TimestampType)
+ val currentTimeOfDayMicros = instantToMicrosOfDay(instant, conf.sessionLocalTimeZone)
val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)
val currentDates = collection.mutable.HashMap.empty[ZoneId, Literal]
val localTimestamps = collection.mutable.HashMap.empty[ZoneId, Literal]
@@ -129,6 +131,10 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
Literal.create(
DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
})
+ case currentTimeType : CurrentTime =>
+ val truncatedTime = truncateTimeMicrosToPrecision(currentTimeOfDayMicros,
+ currentTimeType.precision)
+ Literal.create(truncatedTime, TimeType(currentTimeType.precision))
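+ // e.g. (illustrative) with precision 3, 15:49:11.914120 is truncated to 15:49:11.914
+ // before being folded into the TIME literal.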
case CurrentTimestamp() | Now() => currentTime
case CurrentTimeZone() => timezone
case localTimestamp: LocalTimestamp =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 5fb30e810649b..9413bd7b454d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -293,19 +293,23 @@ trait JoinSelectionHelper extends Logging {
join: Join,
hintOnly: Boolean,
conf: SQLConf): Option[BuildSide] = {
- val buildLeft = if (hintOnly) {
- hintToBroadcastLeft(join.hint)
- } else {
- canBroadcastBySize(join.left, conf) && !hintToNotBroadcastLeft(join.hint)
+ def shouldBuildLeft(): Boolean = {
+ if (hintOnly) {
+ hintToBroadcastLeft(join.hint)
+ } else {
+ canBroadcastBySize(join.left, conf) && !hintToNotBroadcastLeft(join.hint)
+ }
}
- val buildRight = if (hintOnly) {
- hintToBroadcastRight(join.hint)
- } else {
- canBroadcastBySize(join.right, conf) && !hintToNotBroadcastRight(join.hint)
+ def shouldBuildRight(): Boolean = {
+ if (hintOnly) {
+ hintToBroadcastRight(join.hint)
+ } else {
+ canBroadcastBySize(join.right, conf) && !hintToNotBroadcastRight(join.hint)
+ }
}
getBuildSide(
- canBuildBroadcastLeft(join.joinType) && buildLeft,
- canBuildBroadcastRight(join.joinType) && buildRight,
+ canBuildBroadcastLeft(join.joinType) && shouldBuildLeft(),
+ canBuildBroadcastRight(join.joinType) && shouldBuildRight(),
join.left,
join.right
)
@@ -315,25 +319,29 @@ trait JoinSelectionHelper extends Logging {
join: Join,
hintOnly: Boolean,
conf: SQLConf): Option[BuildSide] = {
- val buildLeft = if (hintOnly) {
- hintToShuffleHashJoinLeft(join.hint)
- } else {
- hintToPreferShuffleHashJoinLeft(join.hint) ||
- (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(join.left, conf) &&
- muchSmaller(join.left, join.right, conf)) ||
- forceApplyShuffledHashJoin(conf)
+ def shouldBuildLeft(): Boolean = {
+ if (hintOnly) {
+ hintToShuffleHashJoinLeft(join.hint)
+ } else {
+ hintToPreferShuffleHashJoinLeft(join.hint) ||
+ (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(join.left, conf) &&
+ muchSmaller(join.left, join.right, conf)) ||
+ forceApplyShuffledHashJoin(conf)
+ }
}
- val buildRight = if (hintOnly) {
- hintToShuffleHashJoinRight(join.hint)
- } else {
- hintToPreferShuffleHashJoinRight(join.hint) ||
- (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(join.right, conf) &&
- muchSmaller(join.right, join.left, conf)) ||
- forceApplyShuffledHashJoin(conf)
+ def shouldBuildRight(): Boolean = {
+ if (hintOnly) {
+ hintToShuffleHashJoinRight(join.hint)
+ } else {
+ hintToPreferShuffleHashJoinRight(join.hint) ||
+ (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(join.right, conf) &&
+ muchSmaller(join.right, join.left, conf)) ||
+ forceApplyShuffledHashJoin(conf)
+ }
}
getBuildSide(
- canBuildShuffledHashJoinLeft(join.joinType) && buildLeft,
- canBuildShuffledHashJoinRight(join.joinType) && buildRight,
+ canBuildShuffledHashJoinLeft(join.joinType) && shouldBuildLeft(),
+ canBuildShuffledHashJoinRight(join.joinType) && shouldBuildRight(),
join.left,
join.right
)
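The refactor above turns the eagerly computed `buildLeft`/`buildRight` vals into defs, so the size-estimation checks only run when the join type actually allows building on that side. A small Spark-independent sketch of why a def plus short-circuiting `&&` skips the work (the names are stand-ins, not Spark APIs):

{{{
def expensiveCheck(label: String): Boolean = {
  println(s"evaluating $label") // stands in for canBroadcastBySize / statistics access
  true
}

val canBuildBroadcastLeft = false

// Old shape: the check runs eagerly even though its result is never needed.
val eager = expensiveCheck("eager")          // prints "evaluating eager"
val a = canBuildBroadcastLeft && eager       // false

// New shape: the def is only invoked when the left operand is true, so nothing prints.
def deferred: Boolean = expensiveCheck("deferred")
val b = canBuildBroadcastLeft && deferred    // false; "evaluating deferred" never printed
}}}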
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala
index c17409a68c963..216136d8a7c82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CompoundPlanStatement, Logic
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
/**
* Base class for all ANTLR4 [[ParserInterface]] implementations.
@@ -102,6 +103,13 @@ abstract class AbstractSqlParser extends AbstractParser with ParserInterface {
}
}
+ override def parseRoutineParam(sqlText: String): StructType = parse(sqlText) { parser =>
+ val ctx = parser.singleRoutineParamList()
+ withErrorHandling(ctx, Some(sqlText)) {
+ astBuilder.visitSingleRoutineParamList(ctx)
+ }
+ }
+
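For context, a hedged usage sketch of the new entry point. It assumes a build with this patch and that `CatalystSqlParser` (an existing `AbstractSqlParser` implementation) therefore exposes `parseRoutineParam`; the exact parameter syntax accepted is whatever `singleRoutineParamList` defines, and the string below is only an assumed example:

{{{
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.StructType

// A COMMENT becomes the field's "comment" metadata; a DEFAULT keeps its original
// SQL text under StructType.SQL_FUNCTION_DEFAULT_METADATA_KEY, as
// visitSingleRoutineParamList does.
val params: StructType =
  CatalystSqlParser.parseRoutineParam("x INT COMMENT 'first operand', y INT DEFAULT 0")
params.fields.foreach(f => println(s"${f.name}: ${f.dataType} ${f.metadata}"))
}}}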
def withErrorHandling[T](ctx: ParserRuleContext, sqlText: Option[String])(toResult: => T): T = {
withOrigin(ctx, sqlText) {
try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 523b7c88fc8ce..b15c5b17332ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.parser
-import java.util.Locale
+import java.util.{List, Locale}
import java.util.concurrent.TimeUnit
import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Set}
@@ -25,7 +25,6 @@ import scala.jdk.CollectionConverters._
import scala.util.{Left, Right}
import org.antlr.v4.runtime.{ParserRuleContext, RuleContext, Token}
-import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkThrowable, SparkThrowableHelper}
@@ -44,7 +43,7 @@ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, CollationFactory, DateTimeUtils, IntervalUtils, SparkParserUtils}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTime, stringToTimestamp, stringToTimestampWithoutTimeZone}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege}
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
@@ -117,6 +116,21 @@ class AstBuilder extends DataTypeAstBuilder
}
}
+ /**
+ * Retrieves the original input text for a given parser context, preserving all whitespace and
+ * formatting.
+ *
+ * ANTLR's default getText method removes whitespace because lexer rules typically skip it.
+ * This utility method extracts the exact text from the original input stream, using token
+ * indices.
+ *
+ * @param ctx The parser context to retrieve original text from.
+ * @return The original input text, including all whitespace and formatting.
+ */
+ private def getOriginalText(ctx: ParserRuleContext): String = {
+ SparkParserUtils.source(ctx)
+ }
+
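Why a dedicated helper is needed at all: concatenating token texts drops the whitespace the lexer skipped, while slicing the original character stream between the first and last token offsets (what `SparkParserUtils.source` does) preserves it. A toy, ANTLR-free illustration (`Tok` is an assumed stand-in for ANTLR tokens, stop index inclusive):

{{{
case class Tok(text: String, start: Int, stop: Int)

val input  = "a  +   b"
val tokens = Seq(Tok("a", 0, 0), Tok("+", 3, 3), Tok("b", 7, 7))

val concatenated = tokens.map(_.text).mkString                              // "a+b": whitespace lost
val original     = input.substring(tokens.head.start, tokens.last.stop + 1) // "a  +   b"
assert(original == "a  +   b")
}}}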
/**
* Override the default behavior for all visit methods. This will only return a non-null result
* when the context has only one child. This is done because there is no generic method to
@@ -813,12 +827,15 @@ class AstBuilder extends DataTypeAstBuilder
}
/**
- * Parameters used for writing query to a table:
- * (table ident, options, tableColumnList, partitionKeys, ifPartitionNotExists, byName).
+ * Parameters used for writing query to a table.
*/
- type InsertTableParams =
- (IdentifierReferenceContext, Option[OptionsClauseContext], Seq[String],
- Map[String, Option[String]], Boolean, Boolean)
+ case class InsertTableParams(
+ relationCtx: IdentifierReferenceContext,
+ options: Option[OptionsClauseContext],
+ userSpecifiedCols: Seq[String],
+ partitionSpec: Map[String, Option[String]],
+ ifPartitionNotExists: Boolean,
+ byName: Boolean)
/**
* Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider).
@@ -845,31 +862,36 @@ class AstBuilder extends DataTypeAstBuilder
// 2. Write commands do not hold the table logical plan as a child, and we need to add
// additional resolution code to resolve identifiers inside the write commands.
case table: InsertIntoTableContext =>
- val (relationCtx, options, cols, partition, ifPartitionNotExists, byName)
- = visitInsertIntoTable(table)
- withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
+ val insertParams = visitInsertIntoTable(table)
+ withIdentClause(insertParams.relationCtx, Seq(query), (ident, otherPlans) => {
InsertIntoStatement(
- createUnresolvedRelation(relationCtx, ident, options, Seq(TableWritePrivilege.INSERT)),
- partition,
- cols,
- otherPlans.head,
+ table = createUnresolvedRelation(
+ ctx = insertParams.relationCtx,
+ ident = ident,
+ optionsClause = insertParams.options,
+ writePrivileges = Seq(TableWritePrivilege.INSERT)),
+ partitionSpec = insertParams.partitionSpec,
+ userSpecifiedCols = insertParams.userSpecifiedCols,
+ query = otherPlans.head,
overwrite = false,
- ifPartitionNotExists,
- byName)
+ ifPartitionNotExists = insertParams.ifPartitionNotExists,
+ byName = insertParams.byName)
})
case table: InsertOverwriteTableContext =>
- val (relationCtx, options, cols, partition, ifPartitionNotExists, byName)
- = visitInsertOverwriteTable(table)
- withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
+ val insertParams = visitInsertOverwriteTable(table)
+ withIdentClause(insertParams.relationCtx, Seq(query), (ident, otherPlans) => {
InsertIntoStatement(
- createUnresolvedRelation(relationCtx, ident, options,
- Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
- partition,
- cols,
- otherPlans.head,
+ table = createUnresolvedRelation(
+ ctx = insertParams.relationCtx,
+ ident = ident,
+ optionsClause = insertParams.options,
+ writePrivileges = Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
+ partitionSpec = insertParams.partitionSpec,
+ userSpecifiedCols = insertParams.userSpecifiedCols,
+ query = otherPlans.head,
overwrite = true,
- ifPartitionNotExists,
- byName)
+ ifPartitionNotExists = insertParams.ifPartitionNotExists,
+ byName = insertParams.byName)
})
case ctx: InsertIntoReplaceWhereContext =>
val options = Option(ctx.optionsClause())
@@ -896,8 +918,9 @@ class AstBuilder extends DataTypeAstBuilder
*/
override def visitInsertIntoTable(
ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) {
- val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
- val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+ val userSpecifiedCols =
+ Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
+ val partitionSpec = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
blockBang(ctx.errorCapturingNot())
@@ -905,8 +928,13 @@ class AstBuilder extends DataTypeAstBuilder
invalidStatement("INSERT INTO ... IF NOT EXISTS", ctx)
}
- (ctx.identifierReference, Option(ctx.optionsClause()), cols, partitionKeys, false,
- ctx.NAME() != null)
+ InsertTableParams(
+ relationCtx = ctx.identifierReference(),
+ options = Option(ctx.optionsClause()),
+ userSpecifiedCols = userSpecifiedCols,
+ partitionSpec = partitionSpec,
+ ifPartitionNotExists = false,
+ byName = ctx.NAME() != null)
}
/**
@@ -915,19 +943,25 @@ class AstBuilder extends DataTypeAstBuilder
override def visitInsertOverwriteTable(
ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) {
assert(ctx.OVERWRITE() != null)
- val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
- val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+ val userSpecifiedCols =
+ Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
+ val partitionSpec = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
blockBang(ctx.errorCapturingNot())
- val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
+ val dynamicPartitionKeys: Map[String, Option[String]] = partitionSpec.filter(_._2.isEmpty)
if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) {
operationNotAllowed("IF NOT EXISTS with dynamic partitions: " +
dynamicPartitionKeys.keys.mkString(", "), ctx)
}
- (ctx.identifierReference, Option(ctx.optionsClause()), cols, partitionKeys,
- ctx.EXISTS() != null, ctx.NAME() != null)
+ InsertTableParams(
+ relationCtx = ctx.identifierReference,
+ options = Option(ctx.optionsClause()),
+ userSpecifiedCols = userSpecifiedCols,
+ partitionSpec = partitionSpec,
+ ifPartitionNotExists = ctx.EXISTS() != null,
+ byName = ctx.NAME() != null)
}
/**
@@ -1249,17 +1283,17 @@ class AstBuilder extends DataTypeAstBuilder
val withOrder = if (
!order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
clause = PipeOperators.orderByClause
- Sort(order.asScala.map(visitSortItem).toSeq, global = true, query)
+ Sort(order.asScala.map(visitSortItemAndReplaceOrdinals).toSeq, global = true, query)
} else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
clause = PipeOperators.sortByClause
- Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query)
+ Sort(sort.asScala.map(visitSortItemAndReplaceOrdinals).toSeq, global = false, query)
} else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
clause = PipeOperators.distributeByClause
withRepartitionByExpression(ctx, expressionList(distributeBy), query)
} else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
clause = PipeOperators.sortByDistributeByClause
Sort(
- sort.asScala.map(visitSortItem).toSeq,
+ sort.asScala.map(visitSortItemAndReplaceOrdinals).toSeq,
global = false,
withRepartitionByExpression(ctx, expressionList(distributeBy), query))
} else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
@@ -1517,11 +1551,16 @@ class AstBuilder extends DataTypeAstBuilder
val newProjectList: Seq[NamedExpression] = if (isPipeOperatorSelect) {
// If this is a pipe operator |> SELECT clause, add a [[PipeExpression]] wrapping
// each alias in the project list, so the analyzer can check invariants later.
+ def withPipeExpression(node: UnaryExpression): NamedExpression = {
+ node.withNewChildren(Seq(
+ PipeExpression(node.child, isAggregate = false, PipeOperators.selectClause)))
+ .asInstanceOf[NamedExpression]
+ }
namedExpressions.map {
case a: Alias =>
- a.withNewChildren(Seq(
- PipeExpression(a.child, isAggregate = false, PipeOperators.selectClause)))
- .asInstanceOf[NamedExpression]
+ withPipeExpression(a)
+ case u: UnresolvedAlias =>
+ withPipeExpression(u)
case other =>
other
}
@@ -1775,7 +1814,6 @@ class AstBuilder extends DataTypeAstBuilder
throw new ParseException(
command = Some(SparkParserUtils.command(n)),
start = Origin(),
- stop = Origin(),
errorClass = "PARSE_SYNTAX_ERROR",
messageParameters = Map(
"error" -> s"'$error'",
@@ -1784,24 +1822,27 @@ class AstBuilder extends DataTypeAstBuilder
}
visitNamedExpression(n)
}.toSeq
+ val groupByExpressionsWithOrdinals =
+ replaceOrdinalsInGroupingExpressions(groupByExpressions)
if (ctx.GROUPING != null) {
// GROUP BY ... GROUPING SETS (...)
// `groupByExpressions` can be non-empty for Hive compatibility. It may add extra grouping
// expressions that do not exist in GROUPING SETS (...), and the value is always null.
// For example, `SELECT a, b, c FROM ... GROUP BY a, b, c GROUPING SETS (a, b)`, the output
// of column `c` is always null.
- val groupingSets =
- ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq)
- Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)),
- selectExpressions, query)
+ val groupingSetsWithOrdinals = visitGroupingSetAndReplaceOrdinals(ctx.groupingSet)
+ Aggregate(
+ Seq(GroupingSets(groupingSetsWithOrdinals, groupByExpressionsWithOrdinals)),
+ selectExpressions, query
+ )
} else {
// GROUP BY .... (WITH CUBE | WITH ROLLUP)?
val mappedGroupByExpressions = if (ctx.CUBE != null) {
- Seq(Cube(groupByExpressions.map(Seq(_))))
+ Seq(Cube(groupByExpressionsWithOrdinals.map(Seq(_))))
} else if (ctx.ROLLUP != null) {
- Seq(Rollup(groupByExpressions.map(Seq(_))))
+ Seq(Rollup(groupByExpressionsWithOrdinals.map(Seq(_))))
} else {
- groupByExpressions
+ groupByExpressionsWithOrdinals
}
Aggregate(mappedGroupByExpressions, selectExpressions, query)
}
@@ -1815,8 +1856,12 @@ class AstBuilder extends DataTypeAstBuilder
} else {
expression(groupByExpr.expression)
}
- })
- Aggregate(groupByExpressions.toSeq, selectExpressions, query)
+ }).toSeq
+ Aggregate(
+ groupingExpressions = replaceOrdinalsInGroupingExpressions(groupByExpressions),
+ aggregateExpressions = selectExpressions,
+ child = query
+ )
}
}
@@ -1824,7 +1869,7 @@ class AstBuilder extends DataTypeAstBuilder
groupingAnalytics: GroupingAnalyticsContext): BaseGroupingSets = {
val groupingSets = groupingAnalytics.groupingSet.asScala
.map(_.expression.asScala.map(e => expression(e)).toSeq)
- if (groupingAnalytics.CUBE != null) {
+ val baseGroupingSet = if (groupingAnalytics.CUBE != null) {
// CUBE(A, B, (A, B), ()) is not supported.
if (groupingSets.exists(_.isEmpty)) {
throw QueryParsingErrors.invalidGroupingSetError("CUBE", groupingAnalytics)
@@ -1848,6 +1893,9 @@ class AstBuilder extends DataTypeAstBuilder
}
GroupingSets(groupingSets.toSeq)
}
+ baseGroupingSet.withNewChildren(
+ newChildren = replaceOrdinalsInGroupingExpressions(baseGroupingSet.children)
+ ).asInstanceOf[BaseGroupingSets]
}
/**
@@ -2847,12 +2895,14 @@ class AstBuilder extends DataTypeAstBuilder
CurrentDate()
case SqlBaseParser.CURRENT_TIMESTAMP =>
CurrentTimestamp()
+ case SqlBaseParser.CURRENT_TIME =>
+ CurrentTime()
case SqlBaseParser.CURRENT_USER | SqlBaseParser.USER | SqlBaseParser.SESSION_USER =>
CurrentUser()
}
} else {
// If the parser is not in ansi mode, we should return `UnresolvedAttribute`, in case there
- // are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP`.
+ // are columns named `CURRENT_DATE`, `CURRENT_TIMESTAMP` or `CURRENT_TIME`.
UnresolvedAttribute.quoted(ctx.name.getText)
}
}
@@ -3133,7 +3183,9 @@ class AstBuilder extends DataTypeAstBuilder
override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) {
def value: Expression = {
val e = expression(ctx.expression)
- validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx)
+ validate(
+ e.resolved && e.foldable || e.isInstanceOf[Parameter],
+ "Frame bound value must be a literal.", ctx)
e
}
@@ -3327,6 +3379,7 @@ class AstBuilder extends DataTypeAstBuilder
val zoneId = getZoneId(conf.sessionLocalTimeZone)
val specialDate = convertSpecialDate(value, zoneId).map(Literal(_, DateType))
specialDate.getOrElse(toLiteral(stringToDate, DateType))
+ case TIME => toLiteral(stringToTime, TimeType())
case TIMESTAMP_NTZ =>
convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone))
.map(Literal(_, TimestampNTZType))
@@ -3841,27 +3894,52 @@ class AstBuilder extends DataTypeAstBuilder
* DataType parsing
* ******************************************************************************************** */
+ override def visitSingleRoutineParamList(
+ ctx: SingleRoutineParamListContext): StructType = withOrigin(ctx) {
+ val (cols, constraints) = visitColDefinitionList(ctx.colDefinitionList())
+ // Constraints and generated columns should have been rejected by the parser.
+ assert(constraints.isEmpty)
+ for (col <- cols) {
+ assert(col.generationExpression.isEmpty)
+ assert(col.identityColumnSpec.isEmpty)
+ }
+ // Build fields from the columns, converting comments and default values
+ val fields = for (col <- cols) yield {
+ val metadataBuilder = new MetadataBuilder().withMetadata(col.metadata)
+ col.comment.foreach { c =>
+ metadataBuilder.putString("comment", c)
+ }
+ col.defaultValue.foreach { default =>
+ metadataBuilder.putString(
+ StructType.SQL_FUNCTION_DEFAULT_METADATA_KEY, default.originalSQL)
+ }
+ StructField(col.name, col.dataType, col.nullable, metadataBuilder.build())
+ }
+ StructType(fields.toArray)
+ }
+
/**
* Create top level table schema.
*/
- protected def createSchema(ctx: ColDefinitionListContext): StructType = {
- val columns = Option(ctx).toArray.flatMap(visitColDefinitionList)
- StructType(columns.map(_.toV1Column))
+ protected def createSchema(ctx: TableElementListContext): StructType = {
+ val (cols, _) = visitTableElementList(ctx)
+ StructType(cols.map(_.toV1Column))
}
/**
* Get CREATE TABLE column definitions.
*/
override def visitColDefinitionList(
- ctx: ColDefinitionListContext): Seq[ColumnDefinition] = withOrigin(ctx) {
- ctx.colDefinition().asScala.map(visitColDefinition).toSeq
+ ctx: ColDefinitionListContext): TableElementList = withOrigin(ctx) {
+ val (colDefs, constraints) = ctx.colDefinition().asScala.map(visitColDefinition).toSeq.unzip
+ (colDefs, constraints.flatten)
}
/**
* Get a CREATE TABLE column definition.
*/
override def visitColDefinition(
- ctx: ColDefinitionContext): ColumnDefinition = withOrigin(ctx) {
+ ctx: ColDefinitionContext): ColumnAndConstraint = withOrigin(ctx) {
import ctx._
val name: String = colName.getText
@@ -3870,6 +3948,7 @@ class AstBuilder extends DataTypeAstBuilder
var defaultExpression: Option[DefaultExpressionContext] = None
var generationExpression: Option[GenerationExpressionContext] = None
var commentSpec: Option[CommentSpecContext] = None
+ var columnConstraint: Option[ColumnConstraintDefinitionContext] = None
ctx.colDefinitionOption().asScala.foreach { option =>
if (option.NULL != null) {
blockBang(option.errorCapturingNot)
@@ -3903,10 +3982,17 @@ class AstBuilder extends DataTypeAstBuilder
}
commentSpec = Some(spec)
}
+ Option(option.columnConstraintDefinition()).foreach { definition =>
+ if (columnConstraint.isDefined) {
+ throw QueryParsingErrors.duplicateTableColumnDescriptor(
+ option, name, "CONSTRAINT")
+ }
+ columnConstraint = Some(definition)
+ }
}
val dataType = typedVisit[DataType](ctx.dataType)
- ColumnDefinition(
+ val columnDef = ColumnDefinition(
name = name,
dataType = dataType,
nullable = nullable,
@@ -3919,8 +4005,61 @@ class AstBuilder extends DataTypeAstBuilder
case ctx: IdentityColumnContext => visitIdentityColumn(ctx, dataType)
}
)
+ val constraint = columnConstraint.map(c => visitColumnConstraintDefinition(name, c))
+ (columnDef, constraint)
+ }
+
+ private def visitColumnConstraintDefinition(
+ columnName: String,
+ ctx: ColumnConstraintDefinitionContext): TableConstraint = {
+ withOrigin(ctx) {
+ val name = if (ctx.name != null) {
+ ctx.name.getText
+ } else {
+ null
+ }
+ val constraintCharacteristic =
+ visitConstraintCharacteristics(ctx.constraintCharacteristic().asScala.toSeq)
+ val expr = visitColumnConstraint(columnName, ctx.columnConstraint())
+
+ expr.withUserProvidedName(name).withUserProvidedCharacteristic(constraintCharacteristic)
+ }
}
+ private def visitColumnConstraint(
+ columnName: String,
+ ctx: ColumnConstraintContext): TableConstraint = withOrigin(ctx) {
+ val columns = Seq(columnName)
+ if (ctx.checkConstraint() != null) {
+ visitCheckConstraint(ctx.checkConstraint())
+ } else if (ctx.uniqueSpec() != null) {
+ visitUniqueSpec(ctx.uniqueSpec(), columns)
+ } else {
+ assert(ctx.referenceSpec() != null)
+ val (tableId, refColumns) = visitReferenceSpec(ctx.referenceSpec())
+ ForeignKeyConstraint(
+ childColumns = columns,
+ parentTableId = tableId,
+ parentColumns = refColumns)
+ }
+ }
+
+ private def visitUniqueSpec(ctx: UniqueSpecContext, columns: Seq[String]): TableConstraint =
+ withOrigin(ctx) {
+ if (ctx.UNIQUE() != null) {
+ UniqueConstraint(columns)
+ } else {
+ PrimaryKeyConstraint(columns)
+ }
+ }
+
+ override def visitReferenceSpec(ctx: ReferenceSpecContext): (Seq[String], Seq[String]) =
+ withOrigin(ctx) {
+ val tableId = visitMultipartIdentifier(ctx.multipartIdentifier())
+ val refColumns = visitIdentifierList(ctx.parentColumns)
+ (tableId, refColumns)
+ }
+
/**
* Create a location string.
*/
@@ -3943,14 +4082,7 @@ class AstBuilder extends DataTypeAstBuilder
if (expr.containsPattern(PARAMETER)) {
throw QueryParsingErrors.parameterMarkerNotAllowed(place, expr.origin)
}
- // Extract the raw expression text so that we can save the user provided text. We don't
- // use `Expression.sql` to avoid storing incorrect text caused by bugs in any expression's
- // `sql` method. Note: `exprCtx.getText` returns a string without spaces, so we need to
- // get the text from the underlying char stream instead.
- val start = exprCtx.getStart.getStartIndex
- val end = exprCtx.getStop.getStopIndex
- val originalSQL = exprCtx.getStart.getInputStream.getText(new Interval(start, end))
- DefaultValueExpression(expr, originalSQL)
+ DefaultValueExpression(expr, getOriginalText(exprCtx))
}
/**
@@ -4141,6 +4273,10 @@ class AstBuilder extends DataTypeAstBuilder
Seq[Transform], Seq[ColumnDefinition], Option[BucketSpec], Map[String, String], OptionList,
Option[String], Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec])
+ type ColumnAndConstraint = (ColumnDefinition, Option[TableConstraint])
+
+ type TableElementList = (Seq[ColumnDefinition], Seq[TableConstraint])
+
/**
* Validate a create table statement and return the [[TableIdentifier]].
*/
@@ -4380,16 +4516,6 @@ class AstBuilder extends DataTypeAstBuilder
}
}
- /**
- * Create a [[ShowNamespaces]] command.
- */
- override def visitShowNamespaces(ctx: ShowNamespacesContext): LogicalPlan = withOrigin(ctx) {
- val multiPart = Option(ctx.multipartIdentifier).map(visitMultipartIdentifier)
- ShowNamespaces(
- UnresolvedNamespace(multiPart.getOrElse(Seq.empty[String])),
- Option(ctx.pattern).map(x => string(visitStringLit(x))))
- }
-
/**
* Create a [[DescribeNamespace]].
*
@@ -4681,6 +4807,40 @@ class AstBuilder extends DataTypeAstBuilder
}
}
+ override def visitTableElementList(ctx: TableElementListContext): TableElementList = {
+ if (ctx == null) {
+ return (Nil, Nil)
+ }
+ withOrigin(ctx) {
+ val columnDefs = new ArrayBuffer[ColumnDefinition]()
+ val constraints = new ArrayBuffer[TableConstraint]()
+
+ ctx.tableElement().asScala.foreach { element =>
+ if (element.tableConstraintDefinition() != null) {
+ constraints += visitTableConstraintDefinition(element.tableConstraintDefinition())
+ } else {
+ val (colDef, constraintOpt) = visitColDefinition(element.colDefinition())
+ columnDefs += colDef
+ constraintOpt.foreach(constraints += _)
+ }
+ }
+
+ // check if there are multiple primary keys
+ val primaryKeys = constraints.filter(_.isInstanceOf[PrimaryKeyConstraint])
+ if (primaryKeys.size > 1) {
+ val primaryKeyColumns =
+ primaryKeys
+ .map(_.asInstanceOf[PrimaryKeyConstraint]
+ .columns
+ .mkString("(", ", ", ")"))
+ .mkString(", ")
+ throw QueryParsingErrors.multiplePrimaryKeysError(ctx, primaryKeyColumns)
+ }
+
+ (columnDefs.toSeq, constraints.toSeq)
+ }
+ }
+
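A simplified, non-Spark sketch of the split `visitTableElementList` performs: column definitions and table-level constraints arrive interleaved, column-level constraints are lifted to the table level, and more than one PRIMARY KEY is rejected. The ADT below is illustrative only:

{{{
sealed trait TableElement
case class ColDef(name: String, constraint: Option[String] = None) extends TableElement
case class TableConstraintDef(constraint: String) extends TableElement

def splitElements(elements: Seq[TableElement]): (Seq[String], Seq[String]) = {
  val cols = elements.collect { case c: ColDef => c.name }
  val constraints = elements.collect {
    case ColDef(name, Some(c)) => s"$c ($name)"
    case TableConstraintDef(c) => c
  }
  val primaryKeys = constraints.filter(_.startsWith("PRIMARY KEY"))
  require(primaryKeys.size <= 1, s"multiple primary keys: ${primaryKeys.mkString(", ")}")
  (cols, constraints)
}

// CREATE TABLE t (id INT PRIMARY KEY, name STRING,
//                 FOREIGN KEY (name) REFERENCES o(name))
val (cols, constraints) = splitElements(Seq(
  ColDef("id", Some("PRIMARY KEY")),
  ColDef("name"),
  TableConstraintDef("FOREIGN KEY (name) REFERENCES o(name)")))
// cols = Seq(id, name)
// constraints = Seq(PRIMARY KEY (id), FOREIGN KEY (name) REFERENCES o(name))
}}}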
/**
* Create a table, returning a [[CreateTable]] or [[CreateTableAsSelect]] logical plan.
*
@@ -4715,10 +4875,12 @@ class AstBuilder extends DataTypeAstBuilder
val (identifierContext, temp, ifNotExists, external) =
visitCreateTableHeader(ctx.createTableHeader)
- val columns = Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse(Nil)
+ val (columns, constraints) = visitTableElementList(ctx.tableElementList())
+
val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText)
val (partTransforms, partCols, bucketSpec, properties, options, location, comment,
- collation, serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses())
+ collation, serdeInfo, clusterBySpec) =
+ visitCreateTableClauses(ctx.createTableClauses())
if (provider.isDefined && serdeInfo.isDefined) {
invalidStatement(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx)
@@ -4735,33 +4897,37 @@ class AstBuilder extends DataTypeAstBuilder
bucketSpec.map(_.asTransform) ++
clusterBySpec.map(_.asTransform)
- val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment,
- collation, serdeInfo, external)
-
- Option(ctx.query).map(plan) match {
- case Some(_) if columns.nonEmpty =>
- operationNotAllowed(
- "Schema may not be specified in a Create Table As Select (CTAS) statement",
- ctx)
+ val asSelectPlan = Option(ctx.query).map(plan).toSeq
+ withIdentClause(identifierContext, asSelectPlan, (identifiers, otherPlans) => {
+ val namedConstraints =
+ constraints.map(c => c.withTableName(identifiers.last))
+ val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment,
+ collation, serdeInfo, external, namedConstraints)
+ val identifier = withOrigin(identifierContext) {
+ UnresolvedIdentifier(identifiers)
+ }
+ otherPlans.headOption match {
+ case Some(_) if columns.nonEmpty =>
+ operationNotAllowed(
+ "Schema may not be specified in a Create Table As Select (CTAS) statement",
+ ctx)
- case Some(_) if partCols.nonEmpty =>
- // non-reference partition columns are not allowed because schema can't be specified
- operationNotAllowed(
- "Partition column types may not be specified in Create Table As Select (CTAS)",
- ctx)
+ case Some(_) if partCols.nonEmpty =>
+ // non-reference partition columns are not allowed because schema can't be specified
+ operationNotAllowed(
+ "Partition column types may not be specified in Create Table As Select (CTAS)",
+ ctx)
- case Some(query) =>
- CreateTableAsSelect(withIdentClause(identifierContext, UnresolvedIdentifier(_)),
- partitioning, query, tableSpec, Map.empty, ifNotExists)
+ case Some(query) =>
+ CreateTableAsSelect(identifier, partitioning, query, tableSpec, Map.empty, ifNotExists)
- case _ =>
- // Note: table schema includes both the table columns list and the partition columns
- // with data type.
- val allColumns = columns ++ partCols
- CreateTable(
- withIdentClause(identifierContext, UnresolvedIdentifier(_)),
- allColumns, partitioning, tableSpec, ignoreIfExists = ifNotExists)
- }
+ case _ =>
+ // Note: table schema includes both the table columns list and the partition columns
+ // with data type.
+ val allColumns = columns ++ partCols
+ CreateTable(identifier, allColumns, partitioning, tableSpec, ignoreIfExists = ifNotExists)
+ }
+ })
}
/**
@@ -4797,7 +4963,7 @@ class AstBuilder extends DataTypeAstBuilder
val orCreate = ctx.replaceTableHeader().CREATE() != null
val (partTransforms, partCols, bucketSpec, properties, options, location, comment, collation,
serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses())
- val columns = Option(ctx.colDefinitionList()).map(visitColDefinitionList).getOrElse(Nil)
+ val (columns, constraints) = visitTableElementList(ctx.tableElementList())
val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText)
if (provider.isDefined && serdeInfo.isDefined) {
@@ -4809,34 +4975,39 @@ class AstBuilder extends DataTypeAstBuilder
bucketSpec.map(_.asTransform) ++
clusterBySpec.map(_.asTransform)
- val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment,
- collation, serdeInfo, external = false)
-
- Option(ctx.query).map(plan) match {
- case Some(_) if columns.nonEmpty =>
- operationNotAllowed(
- "Schema may not be specified in a Replace Table As Select (RTAS) statement",
- ctx)
+ val identifierContext = ctx.replaceTableHeader().identifierReference()
+ val asSelectPlan = Option(ctx.query).map(plan).toSeq
+ withIdentClause(identifierContext, asSelectPlan, (identifiers, otherPlans) => {
+ val namedConstraints =
+ constraints.map(c => c.withTableName(identifiers.last))
+ val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment,
+ collation, serdeInfo, external = false, namedConstraints)
+ val identifier = withOrigin(identifierContext) {
+ UnresolvedIdentifier(identifiers)
+ }
+ otherPlans.headOption match {
+ case Some(_) if columns.nonEmpty =>
+ operationNotAllowed(
+ "Schema may not be specified in a Replace Table As Select (RTAS) statement",
+ ctx)
- case Some(_) if partCols.nonEmpty =>
- // non-reference partition columns are not allowed because schema can't be specified
- operationNotAllowed(
- "Partition column types may not be specified in Replace Table As Select (RTAS)",
- ctx)
+ case Some(_) if partCols.nonEmpty =>
+ // non-reference partition columns are not allowed because schema can't be specified
+ operationNotAllowed(
+ "Partition column types may not be specified in Replace Table As Select (RTAS)",
+ ctx)
- case Some(query) =>
- ReplaceTableAsSelect(
- withIdentClause(ctx.replaceTableHeader.identifierReference(), UnresolvedIdentifier(_)),
- partitioning, query, tableSpec, writeOptions = Map.empty, orCreate = orCreate)
+ case Some(query) =>
+ ReplaceTableAsSelect(identifier, partitioning, query, tableSpec,
+ writeOptions = Map.empty, orCreate = orCreate)
- case _ =>
- // Note: table schema includes both the table columns list and the partition columns
- // with data type.
- val allColumns = columns ++ partCols
- ReplaceTable(
- withIdentClause(ctx.replaceTableHeader.identifierReference(), UnresolvedIdentifier(_)),
- allColumns, partitioning, tableSpec, orCreate = orCreate)
- }
+ case _ =>
+ // Note: table schema includes both the table columns list and the partition columns
+ // with data type.
+ val allColumns = columns ++ partCols
+ ReplaceTable(identifier, allColumns, partitioning, tableSpec, orCreate = orCreate)
+ }
+ })
}
/**
@@ -5238,6 +5409,112 @@ class AstBuilder extends DataTypeAstBuilder
AlterTableCollation(table, visitCollationSpec(ctx.collationSpec()))
}
+ override def visitTableConstraintDefinition(
+ ctx: TableConstraintDefinitionContext): TableConstraint =
+ withOrigin(ctx) {
+ val name = if (ctx.name != null) {
+ ctx.name.getText
+ } else {
+ null
+ }
+ val constraintCharacteristic =
+ visitConstraintCharacteristics(ctx.constraintCharacteristic().asScala.toSeq)
+ val expr =
+ visitTableConstraint(ctx.tableConstraint()).asInstanceOf[TableConstraint]
+
+ expr.withUserProvidedName(name).withUserProvidedCharacteristic(constraintCharacteristic)
+ }
+
+ override def visitCheckConstraint(ctx: CheckConstraintContext): CheckConstraint =
+ withOrigin(ctx) {
+ val condition = getOriginalText(ctx.expr)
+ CheckConstraint(
+ child = expression(ctx.booleanExpression()),
+ condition = condition)
+ }
+
+
+ override def visitUniqueConstraint(ctx: UniqueConstraintContext): TableConstraint =
+ withOrigin(ctx) {
+ val columns = visitIdentifierList(ctx.identifierList())
+ visitUniqueSpec(ctx.uniqueSpec(), columns)
+ }
+
+ override def visitForeignKeyConstraint(ctx: ForeignKeyConstraintContext): TableConstraint =
+ withOrigin(ctx) {
+ val columns = visitIdentifierList(ctx.identifierList())
+ val (parentTableId, parentColumns) = visitReferenceSpec(ctx.referenceSpec())
+ ForeignKeyConstraint(
+ childColumns = columns,
+ parentTableId = parentTableId,
+ parentColumns = parentColumns)
+ }
+
+ private def visitConstraintCharacteristics(
+ constraintCharacteristics: Seq[ConstraintCharacteristicContext]): ConstraintCharacteristic = {
+ var enforcement: Option[String] = None
+ var rely: Option[String] = None
+ constraintCharacteristics.foreach {
+ case e if e.enforcedCharacteristic() != null =>
+ val text = getOriginalText(e.enforcedCharacteristic()).toUpperCase(Locale.ROOT)
+ if (enforcement.isDefined) {
+ val invalidCharacteristics = s"${enforcement.get}, $text"
+ throw QueryParsingErrors.invalidConstraintCharacteristics(
+ e.enforcedCharacteristic(), invalidCharacteristics)
+ } else {
+ enforcement = Some(text)
+ }
+
+ case r if r.relyCharacteristic() != null =>
+ val text = r.relyCharacteristic().getText.toUpperCase(Locale.ROOT)
+ if (rely.isDefined) {
+ val invalidCharacteristics = s"${rely.get}, $text"
+ throw QueryParsingErrors.invalidConstraintCharacteristics(
+ r.relyCharacteristic(), invalidCharacteristics)
+ } else {
+ rely = Some(text)
+ }
+ }
+ ConstraintCharacteristic(enforcement.map(_ == "ENFORCED"), rely.map(_ == "RELY"))
+ }
+
+ /**
+ * Parse an [[AddConstraint]] command.
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table1 ADD CONSTRAINT constraint_name CHECK (a > 0)
+ * }}}
+ */
+ override def visitAddTableConstraint(ctx: AddTableConstraintContext): LogicalPlan =
+ withOrigin(ctx) {
+ val tableConstraint = visitTableConstraintDefinition(ctx.tableConstraintDefinition())
+ withIdentClause(ctx.identifierReference, identifiers => {
+ val table = UnresolvedTable(identifiers, "ALTER TABLE ... ADD CONSTRAINT")
+ val namedConstraint = tableConstraint.withTableName(identifiers.last)
+ AddConstraint(table, namedConstraint)
+ })
+ }
+
+ /**
+ * Parse a [[DropConstraint]] command.
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table1 DROP CONSTRAINT constraint_name
+ * }}}
+ */
+ override def visitDropTableConstraint(ctx: DropTableConstraintContext): LogicalPlan =
+ withOrigin(ctx) {
+ val table = createUnresolvedTable(
+ ctx.identifierReference, "ALTER TABLE ... DROP CONSTRAINT")
+ DropConstraint(
+ table,
+ ctx.name.getText,
+ ifExists = ctx.EXISTS() != null,
+ cascade = ctx.CASCADE() != null)
+ }
+
/**
* Parse [[SetViewProperties]] or [[SetTableProperties]] commands.
*
@@ -6281,12 +6558,12 @@ class AstBuilder extends DataTypeAstBuilder
case n: NamedExpression =>
newGroupingExpressions += n
newAggregateExpressions += n
- // If the grouping expression is an integer literal, create [[UnresolvedOrdinal]] and
- // [[UnresolvedPipeAggregateOrdinal]] expressions to represent it in the final grouping
- // and aggregate expressions, respectively. This will let the
+ // If the grouping expression is an [[UnresolvedOrdinal]], replace the ordinal value and
+ // create [[UnresolvedPipeAggregateOrdinal]] expressions to represent it in the final
+ // grouping and aggregate expressions, respectively. This will let the
// [[ResolveOrdinalInOrderByAndGroupBy]] rule detect the ordinal in the aggregate list
// and replace it with the corresponding attribute from the child operator.
- case Literal(v: Int, IntegerType) if conf.groupByOrdinal =>
+ case UnresolvedOrdinal(v: Int) =>
newGroupingExpressions += UnresolvedOrdinal(newAggregateExpressions.length + 1)
newAggregateExpressions += UnresolvedAlias(UnresolvedPipeAggregateOrdinal(v), None)
case e: Expression =>
@@ -6307,6 +6584,58 @@ class AstBuilder extends DataTypeAstBuilder
}
}
+ /**
+ * Visits [[SortItemContext]] and replaces top-level [[Literal]]s with [[UnresolvedOrdinal]] in
+ * the resulting expression, if `orderByOrdinal` is enabled.
+ */
+ private def visitSortItemAndReplaceOrdinals(sortItemContext: SortItemContext) = {
+ val visitedSortItem = visitSortItem(sortItemContext)
+ visitedSortItem.withNewChildren(
+ newChildren = Seq(replaceIntegerLiteralWithOrdinal(
+ expression = visitedSortItem.child,
+ canReplaceWithOrdinal = conf.orderByOrdinal
+ ))
+ ).asInstanceOf[SortOrder]
+ }
+
+ /**
+ * Replaces top-level integer [[Literal]]s with [[UnresolvedOrdinal]] in grouping expressions, if
+ * `groupByOrdinal` is enabled.
+ */
+ private def replaceOrdinalsInGroupingExpressions(groupingExpressions: Seq[Expression]) =
+ groupingExpressions.map(groupByExpression =>
+ replaceIntegerLiteralWithOrdinal(
+ expression = groupByExpression,
+ canReplaceWithOrdinal = conf.groupByOrdinal
+ )
+ ).toSeq
+
+ /**
+ * Visits grouping expressions in a [[GroupingSetContext]] and replaces top-level integer
+ * [[Literal]]s with [[UnresolvedOrdinal]]s in resulting expressions, if `groupByOrdinal` is
+ * enabled.
+ */
+ private def visitGroupingSetAndReplaceOrdinals(groupingSet: List[GroupingSetContext]) = {
+ groupingSet.asScala.map(_.expression.asScala.map(e => {
+ val visitedExpression = expression(e)
+ replaceIntegerLiteralWithOrdinal(
+ expression = visitedExpression,
+ canReplaceWithOrdinal = conf.groupByOrdinal
+ )
+ }).toSeq).toSeq
+ }
+
+ /**
+ * Replaces integer [[Literal]] with [[UnresolvedOrdinal]] if `canReplaceWithOrdinal` is true.
+ */
+ private def replaceIntegerLiteralWithOrdinal(
+ expression: Expression,
+ canReplaceWithOrdinal: Boolean = true) = expression match {
+ case literal @ Literal(value: Int, IntegerType) if canReplaceWithOrdinal =>
+ CurrentOrigin.withOrigin(literal.origin) { UnresolvedOrdinal(value) }
+ case other => other
+ }
+
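The helpers above move GROUP BY / ORDER BY ordinal detection to parse time: when the corresponding conf is enabled, a top-level integer literal becomes an `UnresolvedOrdinal` that later rules resolve against the select list. A simplified sketch of the rewrite over a toy expression ADT (illustrative names, not Spark's classes):

{{{
sealed trait Expr
case class IntLit(value: Int) extends Expr
case class Col(name: String) extends Expr
case class Ordinal(position: Int) extends Expr

def replaceOrdinals(groupingExprs: Seq[Expr], groupByOrdinal: Boolean): Seq[Expr] =
  groupingExprs.map {
    case IntLit(v) if groupByOrdinal => Ordinal(v)
    case other => other
  }

// GROUP BY 1, name  ->  GROUP BY ordinal(1), name (only when the conf is enabled)
assert(replaceOrdinals(Seq(IntLit(1), Col("name")), groupByOrdinal = true) ==
  Seq(Ordinal(1), Col("name")))
assert(replaceOrdinals(Seq(IntLit(1)), groupByOrdinal = false) == Seq(IntLit(1)))
}}}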
/**
* Check plan for any parameters.
* If it finds any throws UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
index 3aec1dd431138..f549f440596e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types.StructType
/**
* Interface for a parser.
@@ -62,4 +63,10 @@ trait ParserInterface extends DataTypeParserInterface {
*/
@throws[ParseException]("Text cannot be parsed to a LogicalPlan")
def parseQuery(sqlText: String): LogicalPlan
+
+ /**
+ * Parse a string into a [[StructType]] of routine parameters, handling default values and comments.
+ */
+ @throws[ParseException]("Text cannot be parsed to routine parameters")
+ def parseRoutineParam(sqlText: String): StructType
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index 4377f6b5bc0cf..fe5bdcc00d30a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -62,12 +62,6 @@ object ParserUtils extends SparkParserUtils {
}
}
- /** Get the code that creates the given node. */
- def source(ctx: ParserRuleContext): String = {
- val stream = ctx.getStart.getInputStream
- stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
- }
-
/** Get all the text which comes after the given rule. */
def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
index 62ef65eb11128..18339e81b682f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
@@ -19,7 +19,11 @@ package org.apache.spark.sql.catalyst.plans
import java.util.HashMap
-import org.apache.spark.sql.catalyst.analysis.GetViewColumnByNameAndOrdinal
+import org.apache.spark.sql.catalyst.analysis.{
+ DeduplicateRelations,
+ GetViewColumnByNameAndOrdinal,
+ NormalizeableRelation
+}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
@@ -145,6 +149,11 @@ object NormalizePlan extends PredicateHelper {
.sortBy(_.hashCode())
.reduce(And)
Join(left, right, newJoinType, Some(newCondition), hint)
+ case project: Project
+ if project
+ .getTagValue(DeduplicateRelations.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION)
+ .isDefined =>
+ project.child
case Project(projectList, child) =>
val projList = projectList
.map { e =>
@@ -168,6 +177,8 @@ object NormalizePlan extends PredicateHelper {
cteIdNormalizer.normalizeDef(cteRelationDef)
case cteRelationRef: CTERelationRef =>
cteIdNormalizer.normalizeRef(cteRelationRef)
+ case normalizeableRelation: NormalizeableRelation =>
+ normalizeableRelation.normalize()
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 6c1a8fa0d773b..1269c9bf8ca1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -787,7 +787,7 @@ object QueryPlan extends PredicateHelper {
* "" -> : None
*/
def generateFieldString(fieldName: String, values: Any): String = values match {
- case iter: Iterable[_] if (iter.size == 0) => s"${fieldName}: []"
+ case iter: Iterable[_] if iter.isEmpty => s"${fieldName}: []"
case iter: Iterable[_] => s"${fieldName} [${iter.size}]: ${iter.mkString("[", ", ", "]")}"
case str: String if (str == null || str.isEmpty) => s"${fieldName}: None"
case str: String => s"${fieldName}: ${str}"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
index fd987c47f106e..b8186fa07858a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -17,11 +17,12 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.rules.RuleId
import org.apache.spark.sql.catalyst.rules.UnknownRuleId
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreePatternBits}
+import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.util.Utils
@@ -155,6 +156,35 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
}
}
+ /**
+ * Similar to [[resolveOperatorsUpWithPruning]], but also applies the given partial function to
+ * all the plans in the subqueries of all nodes. This method is useful when we want to rewrite the
+ * whole plan, including its subqueries, in one go.
+ */
+ def resolveOperatorsUpWithSubqueriesAndPruning(
+ cond: TreePatternBits => Boolean,
+ ruleId: RuleId = UnknownRuleId)(
+ rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+ val visit: PartialFunction[LogicalPlan, LogicalPlan] =
+ new PartialFunction[LogicalPlan, LogicalPlan] {
+ override def isDefinedAt(x: LogicalPlan): Boolean = true
+
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ val transformed = plan.transformExpressionsUpWithPruning(
+ t => t.containsPattern(PLAN_EXPRESSION) && cond(t)
+ ) {
+ case subquery: SubqueryExpression =>
+ val newPlan =
+ subquery.plan.resolveOperatorsUpWithSubqueriesAndPruning(cond, ruleId)(rule)
+ subquery.withNewPlan(newPlan)
+ }
+ rule.applyOrElse[LogicalPlan, LogicalPlan](transformed, identity)
+ }
+ }
+
+ resolveOperatorsUpWithPruning(cond, ruleId)(visit)
+ }
+
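A toy, non-Spark illustration of what this method adds over a plain bottom-up transform: the rewrite is also pushed into plans nested inside expressions (the subqueries), so a single pass covers the whole tree. The `Plan`/`Scan`/`Filter` types are assumed for the sketch:

{{{
sealed trait Plan
case class Scan(name: String) extends Plan
case class Filter(subquery: Option[Plan], child: Plan) extends Plan

def transformUpWithSubqueries(plan: Plan)(rule: PartialFunction[Plan, Plan]): Plan = {
  val recursed = plan match {
    case f: Filter => f.copy(
      subquery = f.subquery.map(transformUpWithSubqueries(_)(rule)),
      child = transformUpWithSubqueries(f.child)(rule))
    case leaf => leaf
  }
  rule.applyOrElse(recursed, identity[Plan])
}

// Renames every Scan, including the one inside the "subquery".
val plan = Filter(Some(Scan("inner")), Scan("outer"))
val rewritten = transformUpWithSubqueries(plan) { case Scan(n) => Scan(n.toUpperCase) }
assert(rewritten == Filter(Some(Scan("INNER")), Scan("OUTER")))
}}}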
/** Similar to [[resolveOperatorsUp]], but does it top-down. */
def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
resolveOperatorsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
index 75b2fcd3a5f34..638d20cff928b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
@@ -135,10 +135,10 @@ object NamedParametersSupport {
}
private def toInputParameter(param: ProcedureParameter): InputParameter = {
- val defaultValue = Option(param.defaultValueExpression).map { expr =>
- ResolveDefaultColumns.analyze(param.name, param.dataType, expr, "CALL")
+ val defaultValueExpr = Option(param.defaultValue).map { defaultValue =>
+ ResolveDefaultColumns.analyze(param.name, param.dataType, defaultValue, "CALL")
}
- InputParameter(param.name, defaultValue)
+ InputParameter(param.name, defaultValueExpr)
}
private def defaultRearrange(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index fb1999148d606..60f4453ca23fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -417,9 +417,14 @@ case class Intersect(
private lazy val lazyOutput: Seq[Attribute] = computeOutput()
+ private def computeOutput(): Seq[Attribute] = Intersect.mergeChildOutputs(children.map(_.output))
+}
+
+/** Factory methods for `Intersect` nodes. */
+object Intersect {
/** We don't use right.output because those rows get excluded from the set. */
- private def computeOutput(): Seq[Attribute] =
- left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
+ def mergeChildOutputs(childOutputs: Seq[Seq[Attribute]]): Seq[Attribute] =
+ childOutputs.head.zip(childOutputs.tail.head).map { case (leftAttr, rightAttr) =>
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}
}
@@ -451,11 +456,16 @@ case class Except(
private lazy val lazyOutput: Seq[Attribute] = computeOutput()
+ private def computeOutput(): Seq[Attribute] = Except.mergeChildOutputs(children.map(_.output))
+}
+
+/** Factory methods for `Except` nodes. */
+object Except {
/** We don't use right.output because those rows get excluded from the set. */
- private def computeOutput(): Seq[Attribute] = left.output
+ def mergeChildOutputs(childOutputs: Seq[Seq[Attribute]]): Seq[Attribute] = childOutputs.head
}
-/** Factory for constructing new `Union` nodes. */
+/** Factory methods for `Union` nodes. */
object Union {
def apply(left: LogicalPlan, right: LogicalPlan): Union = {
Union (left :: right :: Nil)
@@ -840,20 +850,30 @@ object View {
// For temporary view, we always use captured sql configs
if (activeConf.useCurrentSQLConfigsForView && !isTempView) return activeConf
- val sqlConf = new SQLConf()
// We retain below configs from current session because they are not captured by view
// as optimization configs but they are still needed during the view resolution.
- // TODO: remove this `retainedConfigs` after the `RelationConversions` is moved to
+ // TODO: remove this `retainedHiveConfigs` after the `RelationConversions` is moved to
// optimization phase.
+ val retainedHiveConfigs = Seq(
+ "spark.sql.hive.convertMetastoreParquet",
+ "spark.sql.hive.convertMetastoreOrc",
+ "spark.sql.hive.convertInsertingPartitionedTable",
+ "spark.sql.hive.convertInsertingUnpartitionedTable",
+ "spark.sql.hive.convertMetastoreCtas"
+ )
+
+ val retainedLoggingConfigs = Seq(
+ "spark.sql.planChangeLog.level",
+ "spark.sql.expressionTreeChangeLog.level"
+ )
+
val retainedConfigs = activeConf.getAllConfs.filter { case (key, _) =>
- Seq(
- "spark.sql.hive.convertMetastoreParquet",
- "spark.sql.hive.convertMetastoreOrc",
- "spark.sql.hive.convertInsertingPartitionedTable",
- "spark.sql.hive.convertInsertingUnpartitionedTable",
- "spark.sql.hive.convertMetastoreCtas"
- ).contains(key) || key.startsWith("spark.sql.catalog.")
+ retainedHiveConfigs.contains(key) || retainedLoggingConfigs.contains(key) || key.startsWith(
+ "spark.sql.catalog."
+ )
}
+
+ val sqlConf = new SQLConf()
for ((k, v) <- configs ++ retainedConfigs) {
sqlConf.settings.put(k, v)
}
@@ -1189,7 +1209,14 @@ object Aggregate {
groupingExpression.forall(e => UnsafeRowUtils.isBinaryStable(e.dataType))
}
- def supportsObjectHashAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
+ def supportsObjectHashAggregate(
+ aggregateExpressions: Seq[AggregateExpression],
+ groupingExpressions: Seq[Expression]): Boolean = {
+ // We should not use hash aggregation on binary unstable types.
+ if (groupingExpressions.exists(e => !UnsafeRowUtils.isBinaryStable(e.dataType))) {
+ return false
+ }
+
aggregateExpressions.map(_.aggregateFunction).exists {
case _: TypedImperativeAggregate[_] => true
case _ => false
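The extra grouping-expression check above keeps hash aggregation away from binary-unstable types, where value equality is not byte equality. A small illustration of the underlying problem, with a case-insensitive comparison standing in for a non-binary collation:

{{{
val a = "spark"
val b = "SPARK"

// "Equal" under the case-insensitive comparison...
val equalUnderCollation = a.equalsIgnoreCase(b)                                    // true
// ...but the binary forms differ, so hashing the bytes would put the two values
// into different buckets and split one logical group in two.
val sameBinary = java.util.Arrays.equals(a.getBytes("UTF-8"), b.getBytes("UTF-8")) // false
println(s"equal under collation: $equalUnderCollation, same bytes: $sameBinary")
}}}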
@@ -1382,6 +1409,8 @@ case class Offset(offsetExpr: Expression, child: LogicalPlan) extends OrderPrese
}
override protected def withNewChildInternal(newChild: LogicalPlan): Offset =
copy(child = newChild)
+
+ override val nodePatterns: Seq[TreePattern] = Seq(OFFSET)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 2f9bf2b52190a..8de801a8ffa19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -178,16 +178,18 @@ case class FlatMapGroupsInPandasWithState(
* @param outputAttrs used to define the output rows
* @param outputMode defines the output mode for the statefulProcessor
* @param timeMode the time mode semantics of the stateful processor for timers and TTL.
+ * @param userFacingDataType the data type of the input and return type in user functions.
* @param child logical plan of the underlying data
* @param initialState logical plan of initial state
* @param initGroupingAttrsLen length of the seq of grouping attributes for initial state dataframe
*/
-case class TransformWithStateInPandas(
+case class TransformWithStateInPySpark(
functionExpr: Expression,
groupingAttributesLen: Int,
outputAttrs: Seq[Attribute],
outputMode: OutputMode,
timeMode: TimeMode,
+ userFacingDataType: TransformWithStateInPySpark.UserFacingDataType.Value,
child: LogicalPlan,
hasInitialState: Boolean,
initialState: LogicalPlan,
@@ -205,7 +207,7 @@ case class TransformWithStateInPandas(
AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
override protected def withNewChildrenInternal(
- newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPandas =
+ newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPySpark =
copy(child = newLeft, initialState = newRight)
def leftAttributes: Seq[Attribute] = {
@@ -225,6 +227,13 @@ case class TransformWithStateInPandas(
}
}
+object TransformWithStateInPySpark {
+ object UserFacingDataType extends Enumeration {
+ val PYTHON_ROW = Value("python_row")
+ val PANDAS = Value("pandas")
+ }
+}
+
/**
* Flatmap cogroups using a udf: iter(pyarrow.RecordBatch) -> iter(pyarrow.RecordBatch)
* This is used by DataFrame.groupby().cogroup().applyInArrow().
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala
index a0def801ee6f7..2b2f9df6abf52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ResolvedFieldName, UnresolvedException}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ClusterBySpec
-import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{Expression, TableConstraint, Unevaluable}
import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, TypeUtils}
import org.apache.spark.sql.connector.catalog.{TableCatalog, TableChange}
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -288,3 +288,27 @@ case class AlterTableCollation(
protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild)
}
+
+/**
+ * The logical plan of the ALTER TABLE ... ADD CONSTRAINT command.
+ */
+case class AddConstraint(
+ table: LogicalPlan,
+ tableConstraint: TableConstraint) extends AlterTableCommand {
+ override def changes: Seq[TableChange] = Seq.empty
+
+ protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild)
+}
+
+/**
+ * The logical plan of the ALTER TABLE ... DROP CONSTRAINT command.
+ */
+case class DropConstraint(
+ table: LogicalPlan,
+ name: String,
+ ifExists: Boolean,
+ cascade: Boolean) extends AlterTableCommand {
+ override def changes: Seq[TableChange] = Seq.empty
+
+ protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 1056a30c5f758..e0d44e7d248ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -23,14 +23,15 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils,
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, UnaryExpression, Unevaluable, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, ReplaceDataProjections, RowDeltaUtils, WriteDeltaProjections}
+import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber, toSQLExpr}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper}
+import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.expressions.filter.Predicate
@@ -656,24 +657,6 @@ case class SetNamespaceLocation(
copy(namespace = newChild)
}
-/**
- * The logical plan of the SHOW NAMESPACES command.
- */
-case class ShowNamespaces(
- namespace: LogicalPlan,
- pattern: Option[String],
- override val output: Seq[Attribute] = ShowNamespaces.getOutputAttrs) extends UnaryCommand {
- override def child: LogicalPlan = namespace
- override protected def withNewChildInternal(newChild: LogicalPlan): ShowNamespaces =
- copy(namespace = newChild)
-}
-
-object ShowNamespaces {
- def getOutputAttrs: Seq[Attribute] = {
- Seq(AttributeReference("namespace", StringType, nullable = false)())
- }
-}
-
/**
* The logical plan of the DESCRIBE relation_name command.
*/
@@ -1520,7 +1503,9 @@ case class UnresolvedTableSpec(
comment: Option[String],
collation: Option[String],
serde: Option[SerdeInfo],
- external: Boolean) extends UnaryExpression with Unevaluable with TableSpecBase {
+ external: Boolean,
+ constraints: Seq[TableConstraint])
+ extends UnaryExpression with Unevaluable with TableSpecBase {
override def dataType: DataType =
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113")
@@ -1566,9 +1551,11 @@ case class TableSpec(
comment: Option[String],
collation: Option[String],
serde: Option[SerdeInfo],
- external: Boolean) extends TableSpecBase {
+ external: Boolean,
+ constraints: Seq[Constraint] = Seq.empty) extends TableSpecBase {
def withNewLocation(newLocation: Option[String]): TableSpec = {
- TableSpec(properties, provider, options, newLocation, comment, collation, serde, external)
+ TableSpec(properties, provider, options, newLocation,
+ comment, collation, serde, external, constraints)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala
index 60d82d81df767..e000317a7e157 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala
@@ -108,6 +108,10 @@ case class QueryExecutionMetering() {
}
}
+object QueryExecutionMetering {
+ val INSTANCE: QueryExecutionMetering = QueryExecutionMetering()
+}
+
case class QueryExecutionMetrics(
time: Long,
numRuns: Long,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 76d36fab2096a..c1fbdb710efe0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils
object RuleExecutor {
- protected val queryExecutionMeter = QueryExecutionMetering()
+ protected val queryExecutionMeter = QueryExecutionMetering.INSTANCE
/** Dump statistics about time spent running specific rules. */
def dumpTimeSpent(): String = {
@@ -65,7 +65,7 @@ class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging {
""".stripMargin
}
- logBasedOnLevel(message())
+ logBasedOnLevel(logLevel)(message())
}
}
}
@@ -83,7 +83,7 @@ class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging {
}
}
- logBasedOnLevel(message())
+ logBasedOnLevel(logLevel)(message())
}
}
@@ -101,18 +101,7 @@ class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging {
""".stripMargin
// scalastyle:on line.size.limit
- logBasedOnLevel(message)
- }
-
- private def logBasedOnLevel(f: => MessageWithContext): Unit = {
- logLevel match {
- case "TRACE" => logTrace(f.message)
- case "DEBUG" => logDebug(f.message)
- case "INFO" => logInfo(f)
- case "WARN" => logWarning(f)
- case "ERROR" => logError(f)
- case _ => logTrace(f.message)
- }
+ logBasedOnLevel(logLevel)(message)
}
}
@@ -153,7 +142,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
override val maxIterationsSetting: String = null) extends Strategy
/** A batch of rules. */
- protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)
+ protected[catalyst] case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)
/** Defines a sequence of rule batches, to be overridden by the implementation. */
protected def batches: Seq[Batch]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index ee5245054bcca..dea7fc33b1849 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -82,6 +82,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.Analyzer$WrapLateralColumnAliasReference" ::
"org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$AnsiCombinedTypeCoercionRule" ::
"org.apache.spark.sql.catalyst.analysis.ApplyCharTypePadding" ::
+ "org.apache.spark.sql.catalyst.analysis.CollationTypeCasts" ::
"org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" ::
"org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases" ::
"org.apache.spark.sql.catalyst.analysis.EliminateUnions" ::
@@ -110,6 +111,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.ResolveUpdateEventTimeWatermarkColumn" ::
"org.apache.spark.sql.catalyst.expressions.EliminatePipeOperators" ::
"org.apache.spark.sql.catalyst.expressions.ValidateAndStripPipeExpressions" ::
+ "org.apache.spark.sql.catalyst.analysis.ResolveUnresolvedHaving" ::
// Catalyst Optimizer rules
"org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
"org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 9856a26346f6a..3ea32f3cc464f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -131,6 +131,7 @@ object TreePattern extends Enumeration {
val LOCAL_RELATION: Value = Value
val LOGICAL_QUERY_STAGE: Value = Value
val NATURAL_LIKE_JOIN: Value = Value
+ val OFFSET: Value = Value
val OUTER_JOIN: Value = Value
val PROJECT: Value = Value
val PYTHON_DATA_SOURCE: Value = Value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala
index a0ed8e5540397..1084e99731510 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala
@@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag
import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, InterpretedOrdering, SortOrder}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, MapData, SQLOrderingUtil}
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, TimeType, VarcharType, VariantType, YearMonthIntervalType}
import org.apache.spark.unsafe.types.{ByteArray, UTF8String, VariantVal}
import org.apache.spark.util.ArrayImplicits._
@@ -54,6 +54,7 @@ object PhysicalDataType {
case DayTimeIntervalType(_, _) => PhysicalLongType
case YearMonthIntervalType(_, _) => PhysicalIntegerType
case DateType => PhysicalIntegerType
+ case _: TimeType => PhysicalLongType
case ArrayType(elementType, containsNull) => PhysicalArrayType(elementType, containsNull)
case StructType(fields) => PhysicalStructType(fields)
case MapType(keyType, valueType, valueContainsNull) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
index 136e8824569e6..25d0f0325520f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
@@ -84,9 +84,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
keys.append(keyNormalized)
values.append(value)
} else {
- if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) {
+ if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION) {
throw QueryExecutionErrors.duplicateMapKeyFoundError(key)
- } else if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+ } else if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.LAST_WIN) {
// Overwrite the previous value, as the policy is last wins.
values(index) = value
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 1f741169898e9..d7cbb9886ba1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -105,6 +105,13 @@ object DateTimeUtils extends SparkDateTimeUtils {
getLocalDateTime(micros, zoneId).getHour
}
+ /**
+ * Returns the hour value of a given TIME (TimeType) value.
+ */
+ def getHoursOfTime(micros: Long): Int = {
+ microsToLocalTime(micros).getHour
+ }
+
/**
* Returns the minute value of a given timestamp value. The timestamp is expressed in
* microseconds since the epoch.
@@ -113,6 +120,13 @@ object DateTimeUtils extends SparkDateTimeUtils {
getLocalDateTime(micros, zoneId).getMinute
}
+ /**
+ * Returns the minute value of a given TIME (TimeType) value.
+ */
+ def getMinutesOfTime(micros: Long): Int = {
+ microsToLocalTime(micros).getMinute
+ }
+
/**
* Returns the second value of a given timestamp value. The timestamp is expressed in
* microseconds since the epoch.
@@ -121,6 +135,12 @@ object DateTimeUtils extends SparkDateTimeUtils {
getLocalDateTime(micros, zoneId).getSecond
}
+ /**
+ * Returns the second value of a given TIME (TimeType) value.
+ */
+ def getSecondsOfTime(micros: Long): Int = {
+ microsToLocalTime(micros).getSecond
+ }
/**
* Returns the seconds part and its fractional part with microseconds.
*/
@@ -128,6 +148,23 @@ object DateTimeUtils extends SparkDateTimeUtils {
Decimal(getMicroseconds(micros, zoneId), 8, 6)
}
+
+ /**
+ * Returns the second value with fraction from a given TIME (TimeType) value.
+ * @param micros
+ * The number of microseconds since the epoch.
+ * @param precision
+ * The time fractional seconds precision, which indicates the number of decimal digits
+ * maintained.
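+ *
+ * For example, micros = 65123456 (i.e. 00:01:05.123456) with precision = 3 yields whole
+ * seconds 5 and a truncated fraction of 0.123, so the method returns Decimal(5.123, 8, 6).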
+ */
+ def getSecondsOfTimeWithFraction(micros: Long, precision: Int): Decimal = {
+ val seconds = (micros / MICROS_PER_SECOND) % SECONDS_PER_MINUTE
+ val scaleFactor = math.pow(10, precision).toLong
+ val scaledFraction = (micros % MICROS_PER_SECOND) * scaleFactor / MICROS_PER_SECOND
+ val fraction = scaledFraction.toDouble / scaleFactor
+ Decimal(seconds + fraction, 8, 6)
+ }
+
/**
* Returns local seconds, including fractional parts, multiplied by 1000000.
*
@@ -749,4 +786,40 @@ object DateTimeUtils extends SparkDateTimeUtils {
throw QueryExecutionErrors.invalidDatetimeUnitError("TIMESTAMPDIFF", unit)
}
}
+
+ /**
+ * Converts separate time fields into a long that represents microseconds since the start of
+ * the day.
+ * @param hours the hour, from 0 to 23
+ * @param minutes the minute, from 0 to 59
+ * @param secsAndMicros the second, from 0 to 59.999999
+ * @return A time value represented as microseconds since the start of the day
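+ *
+ * For example, assuming secsAndMicros has scale 6 so that its unscaled value is microseconds
+ * (e.g. Decimal(3.004000)), hours = 1, minutes = 2 and secsAndMicros = 3.004000 correspond to
+ * 01:02:03.004 and the method returns 3723004000L.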
+ */
+ def timeToMicros(hours: Int, minutes: Int, secsAndMicros: Decimal): Long = {
+ try {
+ val unscaledSecFrac = secsAndMicros.toUnscaledLong
+ val fullSecs = Math.floorDiv(unscaledSecFrac, MICROS_PER_SECOND)
+ // The check against Int.MaxValue handles the case where the full-seconds value is outside
+ // the Int range. Converting such a Long to an Int overflows; the overflow could land inside
+ // the valid seconds range and silently return a wrong value, and for overflowed values
+ // outside the valid seconds range it would produce a misleading error message.
+ // The negative check exists to throw a clearer error message: for negative inputs,
+ // Math.floorDiv rounds toward the next lower integer, so the full-seconds value would not
+ // match the seconds of the original decimal shown in the error message.
+ if (fullSecs > Int.MaxValue || fullSecs < 0) {
+ // Make this error message consistent with what is thrown by LocalTime.of when the
+ // seconds are invalid
+ throw new DateTimeException(
+ s"Invalid value for SecondOfMinute (valid values 0 - 59): ${secsAndMicros.toLong}")
+ }
+
+ val nanos = Math.floorMod(unscaledSecFrac, MICROS_PER_SECOND) * NANOS_PER_MICROS
+ val lt = LocalTime.of(hours, minutes, fullSecs.toInt, nanos.toInt)
+ localTimeToMicros(lt)
+ } catch {
+ case e @ (_: DateTimeException | _: ArithmeticException) =>
+ throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRangeWithoutSuggestion(e)
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index 58b6314e27ade..0706ffdb7c5bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, Optimizer}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
-import org.apache.spark.sql.connector.catalog.{CatalogManager, FunctionCatalog, Identifier, TableCatalog, TableCatalogCapability}
+import org.apache.spark.sql.connector.catalog.{CatalogManager, DefaultValue, FunctionCatalog, Identifier, TableCatalog, TableCatalogCapability}
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
@@ -40,6 +40,7 @@ import org.apache.spark.sql.internal.connector.V1Function
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.Utils
/**
* This object contains fields to help process DEFAULT columns.
@@ -120,7 +121,11 @@ object ResolveDefaultColumns extends QueryErrorsBase
schema.exists(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY))) {
val keywords: Array[String] = SQLConf.get.getConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS)
.toLowerCase().split(",").map(_.trim)
- val allowedTableProviders: Array[String] = keywords.map(_.stripSuffix("*"))
+ val allowedTableProviders: Array[String] = if (Utils.isTesting) {
+ "in-memory" +: keywords.map(_.stripSuffix("*"))
+ } else {
+ keywords.map(_.stripSuffix("*"))
+ }
val addColumnExistingTableBannedProviders: Array[String] =
keywords.filter(_.endsWith("*")).map(_.stripSuffix("*"))
val givenTableProvider: String = tableProvider.getOrElse("").toLowerCase()
@@ -279,8 +284,48 @@ object ResolveDefaultColumns extends QueryErrorsBase
throw QueryCompilationErrors.defaultValuesUnresolvedExprError(
statementType, colName, defaultSQL, ex)
}
+ analyze(colName, dataType, parsed, defaultSQL, statementType)
+ }
+
+ /**
+ * Analyzes the connector default value.
+ *
+ * If the default value is defined as a connector expression, Spark first attempts to convert it
+ * to a Catalyst expression. If conversion fails but a SQL string is provided, the SQL is parsed
+ * instead. If only a SQL string is present, it is parsed directly.
+ *
+ * @return the result of the analysis and constant-folding operation
+ */
+ def analyze(
+ colName: String,
+ dataType: DataType,
+ defaultValue: DefaultValue,
+ statementType: String): Expression = {
+ if (defaultValue.getExpression != null) {
+ V2ExpressionUtils.toCatalyst(defaultValue.getExpression) match {
+ case Some(defaultExpr) =>
+ val defaultSQL = Option(defaultValue.getSql).getOrElse(defaultExpr.sql)
+ analyze(colName, dataType, defaultExpr, defaultSQL, statementType)
+
+ case None if defaultValue.getSql != null =>
+ analyze(colName, dataType, defaultValue.getSql, statementType)
+
+ case _ =>
+ throw SparkException.internalError(s"Can't convert $defaultValue to Catalyst")
+ }
+ } else {
+ analyze(colName, dataType, defaultValue.getSql, statementType)
+ }
+ }
+
+ private def analyze(
+ colName: String,
+ dataType: DataType,
+ defaultExpr: Expression,
+ defaultSQL: String,
+ statementType: String): Expression = {
// Check invariants before moving on to analysis.
- if (parsed.containsPattern(PLAN_EXPRESSION)) {
+ if (defaultExpr.containsPattern(PLAN_EXPRESSION)) {
throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions(
statementType, colName, defaultSQL)
}
@@ -288,7 +333,7 @@ object ResolveDefaultColumns extends QueryErrorsBase
// Analyze the parse result.
val plan = try {
val analyzer: Analyzer = DefaultColumnAnalyzer
- val analyzed = analyzer.execute(Project(Seq(Alias(parsed, colName)()), OneRowRelation()))
+ val analyzed = analyzer.execute(Project(Seq(Alias(defaultExpr, colName)()), OneRowRelation()))
analyzer.checkAnalysis(analyzed)
// Eagerly execute finish-analysis and constant-folding rules before checking whether the
// expression is foldable and resolved.
@@ -331,7 +376,7 @@ object ResolveDefaultColumns extends QueryErrorsBase
val expr = Literal.fromSQL(defaultSQL) match {
// EXISTS_DEFAULT will have a cast from analyze() due to coerceDefaultValue
// hence we need to add timezone to the cast if necessary
- case c: Cast if c.needsTimeZone =>
+ case c: Cast if c.child.resolved && c.needsTimeZone =>
c.withTimeZone(SQLConf.get.sessionLocalTimeZone)
case e: Expression => e
}
@@ -459,15 +504,17 @@ object ResolveDefaultColumns extends QueryErrorsBase
* Any type suitable for assigning into a row using the InternalRow.update method.
*/
def getExistenceDefaultValues(schema: StructType): Array[Any] = {
- schema.fields.map { field: StructField =>
- val defaultValue: Option[String] = field.getExistenceDefaultValue()
- defaultValue.map { _: String =>
- val expr = analyzeExistenceDefaultValue(field)
-
- // The expression should be a literal value by this point, possibly wrapped in a cast
- // function. This is enforced by the execution of commands that assign default values.
- expr.eval()
- }.orNull
+ schema.fields.map(getExistenceDefaultValue)
+ }
+
+ def getExistenceDefaultValue(field: StructField): Any = {
+ if (field.hasExistenceDefaultValue) {
+ val expr = analyzeExistenceDefaultValue(field)
+ // The expression should be a literal value by this point, possibly wrapped in a cast
+ // function. This is enforced by the execution of commands that assign default values.
+ expr.eval()
+ } else {
+ null
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveTableConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveTableConstraints.scala
new file mode 100644
index 0000000000000..f8b46b11b3f90
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveTableConstraints.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog, TableCatalogCapability, TableChange}
+import org.apache.spark.sql.connector.catalog.constraints.Constraint
+import org.apache.spark.sql.errors.QueryCompilationErrors
+
+object ResolveTableConstraints {
+ // Validates that the catalog supports create/replace table with constraints.
+ // Throws an exception if unsupported.
+ def validateCatalogForTableConstraint(
+ constraints: Seq[Constraint],
+ catalog: TableCatalog,
+ ident: Identifier): Unit = {
+ if (constraints.nonEmpty &&
+ !catalog.capabilities().contains(TableCatalogCapability.SUPPORT_TABLE_CONSTRAINT)) {
+ throw QueryCompilationErrors.unsupportedTableOperationError(
+ catalog, ident, "table constraint")
+ }
+ }
+
+ // Validates that the catalog supports ALTER TABLE ADD/DROP CONSTRAINT operations.
+ // Throws an exception if unsupported.
+ def validateCatalogForTableChange(
+ tableChanges: Seq[TableChange],
+ catalog: TableCatalog,
+ ident: Identifier): Unit = {
+ // Check if the table changes contain table constraints.
+ val hasTableConstraint = tableChanges.exists {
+ case _: TableChange.AddConstraint => true
+ case _: TableChange.DropConstraint => true
+ case _ => false
+ }
+ if (hasTableConstraint &&
+ !catalog.capabilities().contains(TableCatalogCapability.SUPPORT_TABLE_CONSTRAINT)) {
+ throw QueryCompilationErrors.unsupportedTableOperationError(
+ catalog, ident, "table constraint")
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
index b678824574476..9e4e25ba1746c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.xml
import java.io.Writer
import java.sql.Timestamp
+import java.util.Base64
import javax.xml.stream.XMLOutputFactory
import scala.collection.Map
@@ -29,10 +30,11 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, DateTimeUtils, MapData, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.types.variant.VariantUtil
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
class StaxXmlGenerator(
- schema: StructType,
+ schema: DataType,
writer: Writer,
options: XmlOptions,
validateStructure: Boolean = true) {
@@ -124,7 +126,13 @@ class StaxXmlGenerator(
* The row to convert
*/
def write(row: InternalRow): Unit = {
- writeChildElement(options.rowTag, schema, row)
+ schema match {
+ case st: StructType if st.fields.forall(f => options.singleVariantColumn.contains(f.name)) =>
+ // If the top-level field is a StructType with only the single Variant column, we ignore
+ // the single variant column layer and directly write the Variant value under the row tag
+ writeChildElement(options.rowTag, VariantType, row.getVariant(0))
+ case _ => writeChildElement(options.rowTag, schema, row)
+ }
if (indentDisabled) {
gen.writeCharacters("\n")
}
@@ -138,6 +146,8 @@ class StaxXmlGenerator(
case (_, _, _) if name == options.valueTag =>
// If this is meant to be value but in no child, write only a value
writeElement(dt, v, options)
+ case (_, VariantType, v: VariantVal) =>
+ writeVariant(name, v, pos = 0)
case (_, _, _) =>
gen.writeStartElement(name)
writeElement(dt, v, options)
@@ -241,4 +251,168 @@ class StaxXmlGenerator(
}
}
}
+
+ /**
+ * Serialize the single Variant value to XML
+ */
+ def write(v: VariantVal): Unit = {
+ writeVariant(options.rowTag, v, pos = 0)
+ }
+
+ /**
+ * Write a Variant field to XML
+ *
+ * @param name The name of the field
+ * @param v The original Variant entity
+ * @param pos The position in the Variant data array where the field value starts
+ */
+ private def writeVariant(name: String, v: VariantVal, pos: Int): Unit = {
+ VariantUtil.getType(v.getValue, pos) match {
+ case VariantUtil.Type.OBJECT =>
+ writeVariantObject(name, v, pos)
+ case VariantUtil.Type.ARRAY =>
+ writeVariantArray(name, v, pos)
+ case _ =>
+ writeVariantPrimitive(name, v, pos)
+ }
+ }
+
+ /**
+ * Write a Variant object to XML. A Variant object is serialized as an XML element, with the child
+ * fields serialized as XML nodes recursively.
+ *
+ * @param name The name of the object field, which is used as the XML element name
+ * @param v The original Variant entity
+ * @param pos The position in the Variant data array where the object value starts
+ */
+ private def writeVariantObject(name: String, v: VariantVal, pos: Int): Unit = {
+ gen.writeStartElement(name)
+ VariantUtil.handleObject(
+ v.getValue,
+ pos,
+ (size, idSize, offsetSize, idStart, offsetStart, dataStart) => {
+ // Traverse the fields of the object and get their names and positions in the original
+ // Variant
+ val elementInfo = (0 until size).map { i =>
+ val id = VariantUtil.readUnsigned(v.getValue, idStart + idSize * i, idSize)
+ val offset =
+ VariantUtil.readUnsigned(v.getValue, offsetStart + offsetSize * i, offsetSize)
+ val elementPos = dataStart + offset
+ val elementName = VariantUtil.getMetadataKey(v.getMetadata, id)
+ (elementName, elementPos)
+ }
+
+ // Partition the fields of the object into XML attributes and elements
+ val (attributes, elements) = elementInfo.partition {
+ case (f, _) =>
+ // Similar to the reader, we use the attributePrefix option to determine whether the field is
+ // an attribute or not.
+ // In addition, we also check if the field is a value tag, in case the value tag also
+ // starts with the attribute prefix.
+ f.startsWith(options.attributePrefix) && f != options.valueTag
+ }
+
+ // We need to write attributes first before the elements.
+ (attributes ++ elements).foreach {
+ case (field, elementPos) =>
+ writeVariant(field, v, elementPos)
+ }
+ }
+ )
+ gen.writeEndElement()
+ }
+
+ /**
+ * Write a Variant array to XML. A Variant array is flattened and written as a sequence of
+ * XML element with the same element name as the array field name.
+ *
+ * @param name The name of the array field
+ * @param v The original Variant entity
+ * @param pos The position in the Variant data array where the array value starts
+ */
+ private def writeVariantArray(name: String, v: VariantVal, pos: Int): Unit = {
+ VariantUtil.handleArray(
+ v.getValue,
+ pos,
+ (size, offsetSize, offsetStart, dataStart) => {
+ // Traverse each item of the array and write each of them as an XML element
+ (0 until size).foreach { i =>
+ val offset =
+ VariantUtil.readUnsigned(v.getValue, offsetStart + offsetSize * i, offsetSize)
+ val elementPos = dataStart + offset
+ // Check if the array element is also of type ARRAY
+ if (VariantUtil.getType(v.getValue, elementPos) == VariantUtil.Type.ARRAY) {
+ // In the XML read/write round-trip case, an [[ArrayType]] cannot have [[ArrayType]] as its
+ // element type, because the element is always wrapped in a [[StructType]]. So this case can
+ // only happen when converting a regular [[DataFrame]] to an XML file. When an [[ArrayType]]
+ // has [[ArrayType]] elements, it is unclear which element name to use in the XML output.
+ writeVariantArray(options.arrayElementName, v, elementPos)
+ } else {
+ writeVariant(name, v, elementPos)
+ }
+ }
+ }
+ )
+ }
+
+ /**
+ * Write a Variant primitive field to XML
+ *
+ * @param name The name of the field
+ * @param v The original Variant entity
+ * @param pos The position in the Variant data array where the field value starts
+ */
+ private def writeVariantPrimitive(name: String, v: VariantVal, pos: Int): Unit = {
+ val primitiveVal: String = VariantUtil.getType(v.getValue, pos) match {
+ case VariantUtil.Type.NULL => Option(options.nullValue).orNull
+ case VariantUtil.Type.BOOLEAN =>
+ VariantUtil.getBoolean(v.getValue, pos).toString
+ case VariantUtil.Type.LONG =>
+ VariantUtil.getLong(v.getValue, pos).toString
+ case VariantUtil.Type.STRING =>
+ VariantUtil.getString(v.getValue, pos)
+ case VariantUtil.Type.DOUBLE =>
+ VariantUtil.getDouble(v.getValue, pos).toString
+ case VariantUtil.Type.DECIMAL =>
+ VariantUtil.getDecimal(v.getValue, pos).toString
+ case VariantUtil.Type.DATE =>
+ dateFormatter.format(VariantUtil.getLong(v.getValue, pos).toInt)
+ case VariantUtil.Type.TIMESTAMP =>
+ timestampFormatter.format(VariantUtil.getLong(v.getValue, pos))
+ case VariantUtil.Type.TIMESTAMP_NTZ =>
+ timestampNTZFormatter.format(
+ DateTimeUtils.microsToLocalDateTime(VariantUtil.getLong(v.getValue, pos))
+ )
+ case VariantUtil.Type.FLOAT => VariantUtil.getFloat(v.getValue, pos).toString
+ case VariantUtil.Type.BINARY =>
+ Base64.getEncoder.encodeToString(VariantUtil.getBinary(v.getValue, pos))
+ case VariantUtil.Type.UUID => VariantUtil.getUuid(v.getValue, pos).toString
+ case _ =>
+ throw new SparkIllegalArgumentException("invalid variant primitive type for XML")
+ }
+
+ val value = if (primitiveVal == null) options.nullValue else primitiveVal
+
+ // Handle attributes first
+ val isAttribute = name.startsWith(options.attributePrefix) && name != options.valueTag
+ if (isAttribute && primitiveVal != null) {
+ gen.writeAttribute(
+ name.substring(options.attributePrefix.length),
+ value
+ )
+ return
+ }
+
+ // Handle value tags
+ if (name == options.valueTag && primitiveVal != null) {
+ gen.writeCharacters(value)
+ return
+ }
+
+ // Handle child elements
+ gen.writeStartElement(name)
+ if (primitiveVal != null) gen.writeCharacters(value)
+ gen.writeEndElement()
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index 4b892da9db255..b17e89b536103 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.xml
import java.io.{BufferedReader, CharConversionException, FileNotFoundException, InputStream, InputStreamReader, IOException, StringReader}
import java.nio.charset.{Charset, MalformedInputException}
import java.text.NumberFormat
+import java.util
import java.util.Locale
import javax.xml.stream.{XMLEventReader, XMLStreamException}
import javax.xml.stream.events._
@@ -28,6 +29,7 @@ import javax.xml.validation.Schema
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.Try
+import scala.util.control.Exception.allCatch
import scala.util.control.NonFatal
import scala.xml.SAXException
@@ -45,7 +47,10 @@ import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.types.variant.{Variant, VariantBuilder}
+import org.apache.spark.types.variant.VariantBuilder.FieldEntry
+import org.apache.spark.types.variant.VariantUtil
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
class StaxXmlParser(
schema: StructType,
@@ -138,11 +143,19 @@ class StaxXmlParser(
xsdSchema.foreach { schema =>
schema.newValidator().validate(new StreamSource(new StringReader(xml)))
}
- val parser = StaxXmlParserUtils.filteredReader(xml)
- val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
- val result = Some(convertObject(parser, schema, rootAttributes))
- parser.close()
- result
+ options.singleVariantColumn match {
+ case Some(_) =>
+ // If the singleVariantColumn is specified, parse the entire xml string as a Variant
+ val v = StaxXmlParser.parseVariant(xml, options)
+ Some(InternalRow(v))
+ case _ =>
+ // Otherwise, parse the xml string as Structs
+ val parser = StaxXmlParserUtils.filteredReader(xml)
+ val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
+ val result = Some(convertObject(parser, schema, rootAttributes))
+ parser.close()
+ result
+ }
} catch {
case e: SparkUpgradeException => throw e
case e@(_: RuntimeException | _: XMLStreamException | _: MalformedInputException
@@ -189,6 +202,8 @@ class StaxXmlParser(
case st: StructType => convertObject(parser, st)
case MapType(StringType, vt, _) => convertMap(parser, vt, attributes)
case ArrayType(st, _) => convertField(parser, st, startElementName)
+ case VariantType =>
+ StaxXmlParser.convertVariant(parser, attributes, options)
case _: StringType =>
convertTo(
StaxXmlParserUtils.currentStructureAsString(
@@ -218,6 +233,8 @@ class StaxXmlParser(
value
case (_: Characters, st: StructType) =>
convertObject(parser, st)
+ case (_: Characters, VariantType) =>
+ StaxXmlParser.convertVariant(parser, Array.empty, options)
case (_: Characters, _: StringType) =>
convertTo(
StaxXmlParserUtils.currentStructureAsString(
@@ -374,11 +391,16 @@ class StaxXmlParser(
val newValue = dt match {
case st: StructType =>
convertObjectWithAttributes(parser, st, field, attributes)
+ case VariantType =>
+ StaxXmlParser.convertVariant(parser, attributes, options)
case dt: DataType =>
convertField(parser, dt, field)
}
row(index) = values :+ newValue
+ case VariantType =>
+ row(index) = StaxXmlParser.convertVariant(parser, attributes, options)
+
case dt: DataType =>
row(index) = convertField(parser, dt, field, attributes)
}
@@ -897,4 +919,245 @@ object StaxXmlParser {
curRecord
}
}
+
+ /**
+ * Parse the input XML string as a Variant value
+ */
+ def parseVariant(xml: String, options: XmlOptions): VariantVal = {
+ val parser = StaxXmlParserUtils.filteredReader(xml)
+ val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
+ val v = convertVariant(parser, rootAttributes, options)
+ parser.close()
+ v
+ }
+
+ /**
+ * Parse an XML element from the XML event stream into a Variant.
+ * This method transforms the XML element along with its attributes and child elements
+ * into a hierarchical Variant data structure that preserves the XML structure.
+ *
+ * @param parser The XML event stream reader positioned after the start element
+ * @param attributes The attributes of the current XML element to be included in the Variant
+ * @param options Configuration options that control how XML is parsed into Variants
+ * @return A Variant representing the XML element with its attributes and child content
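+ *
+ * For example, assuming the default XML options (attributePrefix "_", valueTag "_VALUE"),
+ * the element <book id="1"><title>Spark</title></book> is converted into a variant object
+ * roughly equivalent to {"_id": 1, "title": "Spark"}.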
+ */
+ def convertVariant(
+ parser: XMLEventReader,
+ attributes: Array[Attribute],
+ options: XmlOptions): VariantVal = {
+ val v = convertVariantInternal(parser, attributes, options)
+ new VariantVal(v.getValue, v.getMetadata)
+ }
+
+ private def convertVariantInternal(
+ parser: XMLEventReader,
+ attributes: Array[Attribute],
+ options: XmlOptions): Variant = {
+ // The variant builder for the root startElement
+ val rootBuilder = new VariantBuilder(false)
+ val start = rootBuilder.getWritePos
+
+ // Map to store the variant values of all child fields
+ // Each field could have multiple entries, which means it's an array
+ // The map is sorted by field name, and the ordering respects the case-sensitivity setting
+ val caseSensitivityOrdering: Ordering[String] = if (SQLConf.get.caseSensitiveAnalysis) {
+ (x: String, y: String) => x.compareTo(y)
+ } else {
+ (x: String, y: String) => x.compareToIgnoreCase(y)
+ }
+ val fieldToVariants = collection.mutable.TreeMap.empty[String, java.util.ArrayList[Variant]](
+ caseSensitivityOrdering
+ )
+
+ // Handle attributes first
+ StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options).foreach {
+ case (f, v) =>
+ val builder = new VariantBuilder(false)
+ appendXMLCharacterToVariant(builder, v, options)
+ val variants = fieldToVariants.getOrElseUpdate(f, new java.util.ArrayList[Variant]())
+ variants.add(builder.result())
+ }
+
+ var shouldStop = false
+ while (!shouldStop) {
+ parser.nextEvent() match {
+ case s: StartElement =>
+ // For each child element, convert it to a variant and keep track of it in
+ // fieldToVariants
+ val attributes = s.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray
+ val field = StaxXmlParserUtils.getName(s.asStartElement.getName, options)
+ val variants = fieldToVariants.getOrElseUpdate(field, new java.util.ArrayList[Variant]())
+ variants.add(convertVariantInternal(parser, attributes, options))
+
+ case c: Characters if !c.isWhiteSpace =>
+ // Treat the character data as a value tag field, where we use [[XmlOptions.valueTag]] as
+ // the field key
+ val builder = new VariantBuilder(false)
+ appendXMLCharacterToVariant(builder, c.getData, options)
+ val variants = fieldToVariants.getOrElseUpdate(
+ options.valueTag,
+ new java.util.ArrayList[Variant]()
+ )
+ variants.add(builder.result())
+
+ case _: EndElement =>
+ if (fieldToVariants.nonEmpty) {
+ val onlyValueTagField = fieldToVariants.keySet.forall(_ == options.valueTag)
+ if (onlyValueTagField) {
+ // If the element only has the value tag field, parse the element as a variant primitive
+ rootBuilder.appendVariant(fieldToVariants(options.valueTag).get(0))
+ } else {
+ writeVariantObject(rootBuilder, fieldToVariants)
+ }
+ }
+ shouldStop = true
+
+ case _: EndDocument => shouldStop = true
+
+ case _ => // do nothing
+ }
+ }
+
+ // If the element is empty, we treat it as a Variant null
+ if (rootBuilder.getWritePos == start) {
+ rootBuilder.appendNull()
+ }
+
+ rootBuilder.result()
+ }
+
+ /**
+ * Write a variant object to the variant builder.
+ *
+ * @param builder The variant builder to write to
+ * @param fieldToVariants A map of field names to their corresponding variant values of the object
+ */
+ private def writeVariantObject(
+ builder: VariantBuilder,
+ fieldToVariants: collection.mutable.TreeMap[String, java.util.ArrayList[Variant]]): Unit = {
+ val start = builder.getWritePos
+ val objectFieldEntries = new java.util.ArrayList[FieldEntry]()
+
+ val (lastFieldKey, lastFieldValue) =
+ fieldToVariants.tail.foldLeft(fieldToVariants.head._1, fieldToVariants.head._2) {
+ case ((key, variantVals), (k, v)) =>
+ if (!SQLConf.get.caseSensitiveAnalysis && k.equalsIgnoreCase(key)) {
+ variantVals.addAll(v)
+ (key, variantVals)
+ } else {
+ writeVariantObjectField(key, variantVals, builder, start, objectFieldEntries)
+ (k, v)
+ }
+ }
+
+ writeVariantObjectField(lastFieldKey, lastFieldValue, builder, start, objectFieldEntries)
+
+ // Finish writing the variant object
+ builder.finishWritingObject(start, objectFieldEntries)
+ }
+
+ /**
+ * Write a single field to a variant object
+ *
+ * @param fieldName the name of the object field
+ * @param fieldVariants the variant values of the field. A field can have multiple variant values,
+ * which means it is an array field
+ * @param builder the variant builder
+ * @param objectStart the start position of the variant object in the builder
+ * @param objectFieldEntries a list tracking all fields of the variant object
+ */
+ private def writeVariantObjectField(
+ fieldName: String,
+ fieldVariants: java.util.ArrayList[Variant],
+ builder: VariantBuilder,
+ objectStart: Int,
+ objectFieldEntries: java.util.ArrayList[FieldEntry]): Unit = {
+ val start = builder.getWritePos
+ val fieldId = builder.addKey(fieldName)
+ objectFieldEntries.add(
+ new FieldEntry(fieldName, fieldId, builder.getWritePos - objectStart)
+ )
+
+ val fieldValue = if (fieldVariants.size() > 1) {
+ // If the field has more than one entry, it's an array field. Build a Variant
+ // array as the field value
+ val arrayBuilder = new VariantBuilder(false)
+ val arrayStart = arrayBuilder.getWritePos
+ val offsets = new util.ArrayList[Integer]()
+ fieldVariants.asScala.foreach { v =>
+ offsets.add(arrayBuilder.getWritePos - arrayStart)
+ arrayBuilder.appendVariant(v)
+ }
+ arrayBuilder.finishWritingArray(arrayStart, offsets)
+ arrayBuilder.result()
+ } else {
+ // Otherwise, just use the first variant as the field value
+ fieldVariants.get(0)
+ }
+
+ // Append the field value to the variant builder
+ builder.appendVariant(fieldValue)
+ }
+
+ /**
+ * Convert an XML Character value `s` into a variant value and append the result to `builder`.
+ * The result can only be one of a variant boolean/long/decimal/string. Anything other than
+ * the supported types will be appended to the Variant builder as a string.
+ *
+ * Floating point types (double, float) are not considered to avoid precision loss.
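+ *
+ * For example, "true" and "false" become variant booleans, "42" becomes a variant long,
+ * "3.14" becomes a variant decimal, a null or options.nullValue input becomes a variant null,
+ * and anything else (e.g. "abc") is appended as a variant string.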
+ */
+ private def appendXMLCharacterToVariant(
+ builder: VariantBuilder,
+ s: String,
+ options: XmlOptions): Unit = {
+ if (s == null || s == options.nullValue) {
+ builder.appendNull()
+ return
+ }
+
+ val value = if (options.ignoreSurroundingSpaces) s.trim() else s
+
+ // Exit early for empty strings
+ if (value.isEmpty) {
+ builder.appendString(value)
+ return
+ }
+
+ // Try parsing the value as boolean first
+ if (value.toLowerCase(Locale.ROOT) == "true") {
+ builder.appendBoolean(true)
+ return
+ }
+ if (value.toLowerCase(Locale.ROOT) == "false") {
+ builder.appendBoolean(false)
+ return
+ }
+
+ // Try parsing the value as a long
+ allCatch opt value.toLong match {
+ case Some(l) =>
+ builder.appendLong(l)
+ return
+ case _ =>
+ }
+
+ // Try parsing the value as decimal
+ val decimalParser = ExprUtils.getDecimalParser(options.locale)
+ allCatch opt decimalParser(value) match {
+ case Some(decimalValue) =>
+ var d = decimalValue
+ if (d.scale() < 0) {
+ d = d.setScale(0)
+ }
+ if (d.scale <= VariantUtil.MAX_DECIMAL16_PRECISION &&
+ d.precision <= VariantUtil.MAX_DECIMAL16_PRECISION) {
+ builder.appendDecimal(d)
+ return
+ }
+ case _ =>
+ }
+
+ // If the character is of other primitive types, parse it as a string
+ builder.appendString(value)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index e2c2d9dbc6d63..132bb1e359479 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -165,6 +165,11 @@ class XmlOptions(
val charset = parameters.getOrElse(ENCODING,
parameters.getOrElse(CHARSET, XmlOptions.DEFAULT_CHARSET))
+ // This option takes in a column name and specifies that the entire XML record should be stored
+ // as a single VARIANT type column in the table with the given column name.
+ // E.g. spark.read.format("xml").option("singleVariantColumn", "colName")
+ val singleVariantColumn = parameters.get(SINGLE_VARIANT_COLUMN)
+
def buildXmlFactory(): XMLInputFactory = {
XMLInputFactory.newInstance()
}
@@ -200,7 +205,7 @@ object XmlOptions extends DataSourceOptions {
val COMPRESSION = newOption("compression")
val MULTI_LINE = newOption("multiLine")
val SAMPLING_RATIO = newOption("samplingRatio")
- val COLUMN_NAME_OF_CORRUPT_RECORD = newOption("columnNameOfCorruptRecord")
+ val COLUMN_NAME_OF_CORRUPT_RECORD = newOption(DataSourceOptions.COLUMN_NAME_OF_CORRUPT_RECORD)
val DATE_FORMAT = newOption("dateFormat")
val TIMESTAMP_FORMAT = newOption("timestampFormat")
val TIMESTAMP_NTZ_FORMAT = newOption("timestampNTZFormat")
@@ -208,6 +213,7 @@ object XmlOptions extends DataSourceOptions {
val INDENT = newOption("indent")
val PREFERS_DECIMAL = newOption("prefersDecimal")
val VALIDATE_NAME = newOption("validateName")
+ val SINGLE_VARIANT_COLUMN = newOption("singleVariantColumn")
// Options with alternative
val ENCODING = "encoding"
val CHARSET = "charset"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index 97cc263c56c5f..d6fa7f58d61cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.catalog
import java.util
-import java.util.Collections
+import java.util.{Collections, Locale}
import scala.jdk.CollectionConverters._
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec}
import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog.TableChange._
+import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.expressions.{ClusterByTransform, LiteralValue, Transform}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -296,6 +297,49 @@ private[sql] object CatalogV2Util {
}
}
+ /**
+ * Extracts and validates table constraints from a sequence of table changes.
+ */
+ def collectConstraintChanges(
+ table: Table,
+ changes: Seq[TableChange]): Array[Constraint] = {
+ val constraints = table.constraints()
+
+ def findExistingConstraint(name: String): Option[Constraint] = {
+ constraints.find(_.name.toLowerCase(Locale.ROOT) == name.toLowerCase(Locale.ROOT))
+ }
+
+ changes.foldLeft(constraints) { (constraints, change) =>
+ change match {
+ case add: AddConstraint =>
+ val newConstraint = add.constraint
+ val existingConstraint = findExistingConstraint(newConstraint.name)
+ if (existingConstraint.isDefined) {
+ throw new AnalysisException(
+ errorClass = "CONSTRAINT_ALREADY_EXISTS",
+ messageParameters =
+ Map("constraintName" -> existingConstraint.get.name,
+ "oldConstraint" -> existingConstraint.get.toDDL))
+ }
+ constraints :+ newConstraint
+
+ case drop: DropConstraint =>
+ val existingConstraint = findExistingConstraint(drop.name)
+ if (existingConstraint.isEmpty && !drop.ifExists) {
+ throw new AnalysisException(
+ errorClass = "CONSTRAINT_DOES_NOT_EXIST",
+ messageParameters =
+ Map("constraintName" -> drop.name, "tableName" -> table.name()))
+ }
+ constraints.filterNot(_.name == drop.name)
+
+ case _ =>
+ // ignore non-constraint changes
+ constraints
+ }
+ }.toArray
+ }
+
private def addField(
schema: StructType,
field: StructField,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index fa0a90135934c..b58605ae95420 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -294,6 +294,28 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("tableName" -> toSQLId(tableName)))
}
+ def unsupportedSetOperationOnMapType(mapCol: Attribute, origin: Origin): Throwable = {
+ new AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE",
+ messageParameters = Map(
+ "colName" -> toSQLId(mapCol.name),
+ "dataType" -> toSQLType(mapCol.dataType)
+ ),
+ origin = origin
+ )
+ }
+
+ def unsupportedSetOperationOnVariantType(variantCol: Attribute, origin: Origin): Throwable = {
+ new AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_VARIANT_TYPE",
+ messageParameters = Map(
+ "colName" -> toSQLId(variantCol.name),
+ "dataType" -> toSQLType(variantCol.dataType)
+ ),
+ origin = origin
+ )
+ }
+
def nonPartitionColError(partitionName: String): Throwable = {
new AnalysisException(
errorClass = "NON_PARTITION_COLUMN",
@@ -3179,13 +3201,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
}
def notAllowedToCreatePermanentViewByReferencingTempVarError(
- name: TableIdentifier,
- varName: String): Throwable = {
+ nameParts: Seq[String],
+ varName: Seq[String]): Throwable = {
new AnalysisException(
errorClass = "INVALID_TEMP_OBJ_REFERENCE",
messageParameters = Map(
"obj" -> "VIEW",
- "objName" -> toSQLId(name.nameParts),
+ "objName" -> toSQLId(nameParts),
"tempObj" -> "VARIABLE",
"tempObjName" -> toSQLId(varName)))
}
@@ -3333,10 +3355,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
"config" -> SQLConf.LEGACY_PATH_OPTION_BEHAVIOR.key))
}
- def invalidSingleVariantColumn(): Throwable = {
+ def invalidSingleVariantColumn(schema: DataType): Throwable = {
new AnalysisException(
errorClass = "INVALID_SINGLE_VARIANT_COLUMN",
- messageParameters = Map.empty)
+ messageParameters = Map("schema" -> toSQLType(schema)))
}
def writeWithSaveModeUnsupportedBySourceError(source: String, createMode: String): Throwable = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 9a120827699d7..63f6a907915d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -79,7 +79,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
messageParameters = Map(
"value" -> toSQLValue(t, from),
"sourceType" -> toSQLType(from),
- "targetType" -> toSQLType(to)),
+ "targetType" -> toSQLType(to),
+ "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")),
context = Array.empty,
summary = "")
}
@@ -123,7 +124,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
messageParameters = Map(
"expression" -> toSQLValue(s, StringType),
"sourceType" -> toSQLType(StringType),
- "targetType" -> toSQLType(BooleanType)),
+ "targetType" -> toSQLType(BooleanType),
+ "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")),
context = getQueryContext(context),
summary = getSummary(context))
}
@@ -137,7 +139,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
messageParameters = Map(
"expression" -> toSQLValue(s, StringType),
"sourceType" -> toSQLType(StringType),
- "targetType" -> toSQLType(to)),
+ "targetType" -> toSQLType(to),
+ "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")),
context = getQueryContext(context),
summary = getSummary(context))
}
@@ -276,14 +279,37 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
summary = "")
}
+ def timeParseError(input: String, fmt: Option[String], e: Throwable): SparkDateTimeException = {
+ new SparkDateTimeException(
+ errorClass = "CANNOT_PARSE_TIME",
+ messageParameters = Map(
+ "input" -> toSQLValue(input, StringType),
+ "format" -> toSQLValue(
+ fmt.getOrElse("HH:mm:ss.SSSSSS"),
+ StringType)),
+ context = Array.empty,
+ summary = "",
+ cause = Some(e))
+ }
+
def ansiDateTimeArgumentOutOfRange(e: Exception): SparkDateTimeException = {
new SparkDateTimeException(
- errorClass = "DATETIME_FIELD_OUT_OF_BOUNDS",
+ errorClass = "DATETIME_FIELD_OUT_OF_BOUNDS.WITH_SUGGESTION",
messageParameters = Map(
"rangeMessage" -> e.getMessage,
"ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)),
context = Array.empty,
- summary = "")
+ summary = "",
+ cause = Some(e))
+ }
+
+ def ansiDateTimeArgumentOutOfRangeWithoutSuggestion(e: Throwable): SparkDateTimeException = {
+ new SparkDateTimeException(
+ errorClass = "DATETIME_FIELD_OUT_OF_BOUNDS.WITHOUT_SUGGESTION",
+ messageParameters = Map("rangeMessage" -> e.getMessage),
+ context = Array.empty,
+ summary = "",
+ cause = Some(e))
}
def invalidIntervalWithMicrosecondsAdditionError(): SparkIllegalArgumentException = {
@@ -844,6 +870,19 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
cause = e)
}
+ def arrowDataTypeMismatchError(
+ operation: String,
+ outputTypes: Seq[DataType],
+ actualDataTypes: Seq[DataType]): Throwable = {
+ new SparkException(
+ errorClass = "ARROW_TYPE_MISMATCH",
+ messageParameters = Map(
+ "operation" -> operation,
+ "outputTypes" -> outputTypes.mkString(", "),
+ "actualDataTypes" -> actualDataTypes.mkString(", ")),
+ cause = null)
+ }
+
def cannotReadFilesError(
e: Throwable,
path: String): Throwable = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 7e161fb9b7abe..975c5feb65130 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -21,7 +21,6 @@ import java.util.{Locale, Properties, TimeZone}
import java.util
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
-import java.util.zip.Deflater
import scala.collection.immutable
import scala.jdk.CollectionConverters._
@@ -29,8 +28,10 @@ import scala.util.Try
import scala.util.control.NonFatal
import scala.util.matching.Regex
+import org.apache.avro.file.CodecFactory
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.OutputCommitter
+import org.slf4j.event.Level
import org.apache.spark.{ErrorMessageFormat, SparkConf, SparkContext, SparkException, TaskContext}
import org.apache.spark.internal.Logging
@@ -240,6 +241,15 @@ object SQLConf {
}
}
+ val ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS =
+ buildConf("spark.sql.analyzer.uniqueNecessaryMetadataColumns")
+ .internal()
+ .doc(
+ "When this conf is enabled, AddMetadataColumns rule should only add necessary metadata " +
+ "columns and only if those columns are not already present in the project list.")
+ .booleanConf
+ .createWithDefault(true)
+
val ANALYZER_MAX_ITERATIONS = buildConf("spark.sql.analyzer.maxIterations")
.internal()
.doc("The max number of iterations the analyzer runs.")
@@ -264,6 +274,24 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val ANALYZER_SINGLE_PASS_RESOLVER_ENABLED_TENTATIVELY =
+ buildConf("spark.sql.analyzer.singlePassResolver.enabledTentatively")
+ .internal()
+ .doc(
+ "When true, use the single-pass Resolver instead of the fixed-point Analyzer only if " +
+ "a SQL query or a DataFrame program is fully supported by the single-pass Analyzer. " +
+ "This is an alternative Analyzer framework, which resolves the parsed logical plan in a " +
+ "single post-order traversal. It uses ExpressionResolver to resolve expressions and " +
+ "NameScope to control the visibility of names. In contrast to the current fixed-point " +
+ "framework, subsequent in-tree traversals are disallowed. Most of the fixed-point " +
+ "Analyzer code is reused in the form of specific node transformation functions " +
+ "(AliasResolution.resolve, FunctionResolution.resolveFunction, etc)." +
+ "This feature is currently under development."
+ )
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
val ANALYZER_DUAL_RUN_LEGACY_AND_SINGLE_PASS_RESOLVER =
buildConf("spark.sql.analyzer.singlePassResolver.dualRunWithLegacy")
.internal()
@@ -288,6 +316,19 @@ object SQLConf {
.booleanConf
.createWithDefault(Utils.isTesting)
+ val ANALYZER_DUAL_RUN_SAMPLE_RATE =
+ buildConf("spark.sql.analyzer.singlePassResolver.dualRunSampleRate")
+ .internal()
+ .doc(
+ "Represents the rate of queries that will be run in both fixed-point and single-pass " +
+ "mode (dual run). It should be taken into account that the sample rate is not a strict " +
+ "percentage (in tests we don't sample). It is determined whether query should be run in " +
+ "dual run mode by comparing a random value with the value of this flag."
+ )
+ .version("4.1.0")
+ .doubleConf
+ .createWithDefault(if (Utils.isTesting) 1.0 else 0.001)
+
val ANALYZER_SINGLE_PASS_RESOLVER_VALIDATION_ENABLED =
buildConf("spark.sql.analyzer.singlePassResolver.validationEnabled")
.internal()
@@ -316,7 +357,17 @@ object SQLConf {
)
.version("4.0.0")
.booleanConf
- .createWithDefault(Utils.isTesting)
+ .createWithDefault(true)
+
+ val ANALYZER_SINGLE_PASS_RESOLVER_THROW_FROM_RESOLVER_GUARD =
+ buildConf("spark.sql.analyzer.singlePassResolver.throwFromResolverGuard")
+ .internal()
+ .doc(
+ "When set to true, ResolverGuard will throw a descriptive error on unsupported features."
+ )
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(false)
val MULTI_COMMUTATIVE_OP_OPT_THRESHOLD =
buildConf("spark.sql.analyzer.canonicalization.multiCommutativeOpMemoryOptThreshold")
@@ -362,18 +413,16 @@ object SQLConf {
"for using switch statements in InSet must be non-negative and less than or equal to 600")
.createWithDefault(400)
+ private val VALID_LOG_LEVELS: Array[String] = Level.values.map(_.toString)
+
val PLAN_CHANGE_LOG_LEVEL = buildConf("spark.sql.planChangeLog.level")
.internal()
.doc("Configures the log level for logging the change from the original plan to the new " +
- "plan after a rule or batch is applied. The value can be 'trace', 'debug', 'info', " +
- "'warn', or 'error'. The default log level is 'trace'.")
+ s"plan after a rule or batch is applied. The value can be " +
+ s"${VALID_LOG_LEVELS.mkString(", ")}.")
.version("3.1.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValue(logLevel => Set("TRACE", "DEBUG", "INFO", "WARN", "ERROR").contains(logLevel),
- "Invalid value for 'spark.sql.planChangeLog.level'. Valid values are " +
- "'trace', 'debug', 'info', 'warn' and 'error'.")
- .createWithDefault("trace")
+ .enumConf(classOf[Level])
+ .createWithDefault(Level.TRACE)
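With `enumConf(classOf[Level])`, the getter changed later in this patch returns a typed `org.slf4j.event.Level` instead of a raw string, so call sites can branch on the enum rather than parse text. A rough sketch (the branch bodies are placeholders, not a specific Spark API):

    import org.apache.spark.sql.internal.SQLConf
    import org.slf4j.event.Level

    // `planChangeLogLevel` is now a Level (see the getter change below in this patch).
    SQLConf.get.planChangeLogLevel match {
      case Level.TRACE | Level.DEBUG => // verbose plan-change logging (placeholder)
      case _                         => // terser logging (placeholder)
    }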
val PLAN_CHANGE_LOG_RULES = buildConf("spark.sql.planChangeLog.rules")
.internal()
@@ -403,14 +452,10 @@ object SQLConf {
.internal()
.doc("Configures the log level for logging the change from the unresolved expression tree to " +
"the resolved expression tree in the single-pass bottom-up Resolver. The value can be " +
- "'trace', 'debug', 'info', 'warn', or 'error'. The default log level is 'trace'.")
+ s"${VALID_LOG_LEVELS.mkString(", ")}.")
.version("4.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValue(logLevel => Set("TRACE", "DEBUG", "INFO", "WARN", "ERROR").contains(logLevel),
- "Invalid value for 'spark.sql.expressionTreeChangeLog.level'. Valid values are " +
- "'trace', 'debug', 'info', 'warn' and 'error'.")
- .createWithDefault("trace")
+ .enumConf(classOf[Level])
+ .createWithDefault(Level.TRACE)
val LIGHTWEIGHT_PLAN_CHANGE_VALIDATION = buildConf("spark.sql.lightweightPlanChangeValidation")
.internal()
@@ -780,12 +825,10 @@ object SQLConf {
val ADAPTIVE_EXECUTION_LOG_LEVEL = buildConf("spark.sql.adaptive.logLevel")
.internal()
.doc("Configures the log level for adaptive execution logging of plan changes. The value " +
- "can be 'trace', 'debug', 'info', 'warn', or 'error'. The default log level is 'debug'.")
+ s"can be ${VALID_LOG_LEVELS.mkString(", ")}.")
.version("3.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(Set("TRACE", "DEBUG", "INFO", "WARN", "ERROR"))
- .createWithDefault("debug")
+ .enumConf(classOf[Level])
+ .createWithDefault(Level.DEBUG)
val ADVISORY_PARTITION_SIZE_IN_BYTES =
buildConf("spark.sql.adaptive.advisoryPartitionSizeInBytes")
@@ -1151,10 +1194,8 @@ object SQLConf {
"Unix epoch. TIMESTAMP_MILLIS is also standard, but with millisecond precision, which " +
"means Spark has to truncate the microsecond portion of its timestamp value.")
.version("2.3.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(ParquetOutputTimestampType.values.map(_.toString))
- .createWithDefault(ParquetOutputTimestampType.INT96.toString)
+ .enumConf(ParquetOutputTimestampType)
+ .createWithDefault(ParquetOutputTimestampType.INT96)
val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec")
.doc("Sets the compression codec used when writing Parquet files. If either `compression` or " +
@@ -1493,10 +1534,8 @@ object SQLConf {
"attempt to write it to the table properties) and NEVER_INFER (the default mode-- fallback " +
"to using the case-insensitive metastore schema instead of inferring).")
.version("2.1.1")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString))
- .createWithDefault(HiveCaseSensitiveInferenceMode.NEVER_INFER.toString)
+ .enumConf(HiveCaseSensitiveInferenceMode)
+ .createWithDefault(HiveCaseSensitiveInferenceMode.NEVER_INFER)
val HIVE_TABLE_PROPERTY_LENGTH_THRESHOLD =
buildConf("spark.sql.hive.tablePropertyLengthThreshold")
@@ -1534,6 +1573,12 @@ object SQLConf {
.timeConf(TimeUnit.SECONDS)
.createWithDefaultString(s"${5 * 60}")
+ val MAX_BROADCAST_TABLE_SIZE = buildConf("spark.sql.maxBroadcastTableSize")
+ .doc("The maximum table size in bytes that can be broadcast in broadcast joins.")
+ .version("4.1.0")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(8L << 30)
+
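For reference, the default of `8L << 30` is 8 GiB:

    // 8L << 30 == 8 * 1024 * 1024 * 1024 == 8589934592 bytes == 8 GiB
    assert((8L << 30) == 8589934592L)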
val INTERRUPT_ON_CANCEL = buildConf("spark.sql.execution.interruptOnCancel")
.doc("When true, all running tasks will be interrupted if one cancels a query.")
.version("4.0.0")
@@ -1657,9 +1702,7 @@ object SQLConf {
.doc("The output style used display binary data. Valid values are 'UTF-8', " +
"'BASIC', 'BASE64', 'HEX', and 'HEX_DISCRETE'.")
.version("4.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(BinaryOutputStyle.values.map(_.toString))
+ .enumConf(BinaryOutputStyle)
.createOptional
val PARTITION_COLUMN_TYPE_INFERENCE =
@@ -1798,24 +1841,18 @@ object SQLConf {
.doc("The default storage level of `dataset.cache()`, `catalog.cacheTable()` and " +
"sql query `CACHE TABLE t`.")
.version("4.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(StorageLevelMapper.values.map(_.name()).toSet)
- .createWithDefault(StorageLevelMapper.MEMORY_AND_DISK.name())
+ .enumConf(classOf[StorageLevelMapper])
+ .createWithDefault(StorageLevelMapper.MEMORY_AND_DISK)
val DATAFRAME_CACHE_LOG_LEVEL = buildConf("spark.sql.dataframeCache.logLevel")
.internal()
.doc("Configures the log level of Dataframe cache operations, including adding and removing " +
- "entries from Dataframe cache, hit and miss on cache application. The default log " +
- "level is 'trace'. This log should only be used for debugging purposes and not in the " +
- "production environment, since it generates a large amount of logs.")
+ "entries from Dataframe cache, hit and miss on cache application. This log should only be " +
+ "used for debugging purposes and not in the production environment, since it generates a " +
+ "large amount of logs.")
.version("4.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValue(logLevel => Set("TRACE", "DEBUG", "INFO", "WARN", "ERROR").contains(logLevel),
- "Invalid value for 'spark.sql.dataframeCache.logLevel'. Valid values are " +
- "'trace', 'debug', 'info', 'warn' and 'error'.")
- .createWithDefault("trace")
+ .enumConf(classOf[Level])
+ .createWithDefault(Level.TRACE)
val CROSS_JOINS_ENABLED = buildConf("spark.sql.crossJoin.enabled")
.internal()
@@ -1994,6 +2031,7 @@ object SQLConf {
.createWithDefault(100)
val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode")
+ .internal()
.doc("This config determines the fallback behavior of several codegen generators " +
"during tests. `FALLBACK` means trying codegen first and then falling back to " +
"interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " +
@@ -2001,10 +2039,8 @@ object SQLConf {
"this configuration is only for the internal usage, and NOT supposed to be set by " +
"end users.")
.version("2.4.0")
- .internal()
- .stringConf
- .checkValues(CodegenObjectFactoryMode.values.map(_.toString))
- .createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString)
+ .enumConf(CodegenObjectFactoryMode)
+ .createWithDefault(CodegenObjectFactoryMode.FALLBACK)
val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback")
.internal()
@@ -2014,6 +2050,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val CODEGEN_LOG_LEVEL = buildConf("spark.sql.codegen.logLevel")
+ .internal()
+ .doc("Configures the log level for logging of codegen. " +
+ s"The value can be ${VALID_LOG_LEVELS.mkString(", ")}.")
+ .version("4.1.0")
+ .enumConf(classOf[Level])
+ .createWithDefault(Level.DEBUG)
+
val CODEGEN_LOGGING_MAX_LINES = buildConf("spark.sql.codegen.logging.maxLines")
.internal()
.doc("The maximum number of codegen lines to log when errors occur. Use -1 for unlimited.")
@@ -2209,6 +2253,13 @@ object SQLConf {
.checkValue(_ > 0, "Must be greater than 0")
.createWithDefault(Math.max(Runtime.getRuntime.availableProcessors() / 4, 1))
+ val STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT =
+ buildConf("spark.sql.streaming.stateStore.maintenanceShutdownTimeout")
+ .internal()
+ .doc("Timeout in seconds for maintenance pool operations to complete on shutdown")
+ .timeConf(TimeUnit.SECONDS)
+ .createWithDefault(300L)
+
val STATE_SCHEMA_CHECK_ENABLED =
buildConf("spark.sql.streaming.stateStore.stateSchemaCheck")
.doc("When true, Spark will validate the state schema against schema on existing state and " +
@@ -2217,6 +2268,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
+ buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
+ .internal()
+ .doc("Minimum number of state store delta files that needs to be generated before they " +
+ "consolidated into snapshots.")
+ .version("2.0.0")
+ .intConf
+ .createWithDefault(10)
+
val STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT =
buildConf("spark.sql.streaming.stateStore.numStateStoreInstanceMetricsToReport")
.internal()
@@ -2225,20 +2285,11 @@ object SQLConf {
"per stateful operator. Instance metrics are selected based on metric-specific ordering " +
"to minimize noise in the progress report."
)
- .version("4.0.0")
+ .version("4.1.0")
.intConf
.checkValue(k => k >= 0, "Must be greater than or equal to 0")
.createWithDefault(5)
- val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
- buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
- .internal()
- .doc("Minimum number of state store delta files that needs to be generated before they " +
- "consolidated into snapshots.")
- .version("2.0.0")
- .intConf
- .createWithDefault(10)
-
val STATE_STORE_FORMAT_VALIDATION_ENABLED =
buildConf("spark.sql.streaming.stateStore.formatValidation.enabled")
.internal()
@@ -2249,6 +2300,70 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG =
+ buildConf("spark.sql.streaming.stateStore.multiplierForMinVersionDiffToLog")
+ .internal()
+ .doc(
+ "Determines the version threshold for logging warnings when a state store falls behind. " +
+ "The coordinator logs a warning when the store's uploaded snapshot version trails the " +
+ "query's latest version by the configured number of deltas needed to create a snapshot, " +
+ "times this multiplier."
+ )
+ .version("4.1.0")
+ .longConf
+ .checkValue(k => k >= 1L, "Must be greater than or equal to 1")
+ .createWithDefault(5L)
+
+ val STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG =
+ buildConf("spark.sql.streaming.stateStore.multiplierForMinTimeDiffToLog")
+ .internal()
+ .doc(
+ "Determines the time threshold for logging warnings when a state store falls behind. " +
+ "The coordinator logs a warning when the store's uploaded snapshot timestamp trails the " +
+ "current time by the configured maintenance interval, times this multiplier."
+ )
+ .version("4.1.0")
+ .longConf
+ .checkValue(k => k >= 1L, "Must be greater than or equal to 1")
+ .createWithDefault(10L)
+
+ val STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG =
+ buildConf("spark.sql.streaming.stateStore.coordinatorReportSnapshotUploadLag")
+ .internal()
+ .doc(
+ "When enabled, the state store coordinator will report state stores whose snapshot " +
+ "have not been uploaded for some time. See the conf snapshotLagReportInterval for " +
+ "the minimum time between reports, and the conf multiplierForMinVersionDiffToLog " +
+ "and multiplierForMinTimeDiffToLog for the logging thresholds."
+ )
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL =
+ buildConf("spark.sql.streaming.stateStore.snapshotLagReportInterval")
+ .internal()
+ .doc(
+ "The minimum amount of time between the state store coordinator's reports on " +
+ "state store instances trailing behind in snapshot uploads."
+ )
+ .version("4.1.0")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createWithDefault(TimeUnit.MINUTES.toMillis(5))
+
+ val STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT =
+ buildConf("spark.sql.streaming.stateStore.maxLaggingStoresToReport")
+ .internal()
+ .doc(
+ "Maximum number of state stores the coordinator will report as trailing in " +
+ "snapshot uploads. Stores are selected based on the most lagging behind in " +
+ "snapshot version."
+ )
+ .version("4.1.0")
+ .intConf
+ .checkValue(k => k >= 0, "Must be greater than or equal to 0")
+ .createWithDefault(5)
+
val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
.internal()
@@ -2372,6 +2487,18 @@ object SQLConf {
.stringConf
.createWithDefault(CompressionCodec.LZ4)
+ val STATE_STORE_UNLOAD_ON_COMMIT =
+ buildConf("spark.sql.streaming.stateStore.unloadOnCommit")
+ .internal()
+ .doc("When true, Spark will synchronously run maintenance and then close each StateStore " +
+ "instance on task completion. This removes the overhead of keeping every StateStore " +
+ "loaded indefinitely, at the cost of having to reload each StateStore every batch. " +
+ "Stateful applications that are failing due to resource exhaustion or that use " +
+ "dynamic allocation may benefit from enabling this.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
val CHECKPOINT_RENAMEDFILE_CHECK_ENABLED =
buildConf("spark.sql.streaming.checkpoint.renamedFileCheck.enabled")
.doc("When true, Spark will validate if renamed checkpoint file exists.")
@@ -2422,10 +2549,11 @@ object SQLConf {
.internal()
.doc("State format version used by streaming join operations in a streaming query. " +
"State between versions are tend to be incompatible, so state format version shouldn't " +
- "be modified after running.")
+ "be modified after running. Version 3 uses a single state store with virtual column " +
+ "families instead of four stores and is only supported with RocksDB.")
.version("3.0.0")
.intConf
- .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
+ .checkValue(v => Set(1, 2, 3).contains(v), "Valid versions are 1, 2, and 3")
.createWithDefault(2)
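An illustrative way to opt into the new format for a fresh query, assuming the conf keys below are unchanged from earlier releases (version 3 requires the RocksDB provider, per the doc text above, and the format must not be changed on an existing checkpoint):

    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
    spark.conf.set("spark.sql.streaming.join.stateFormatVersion", "3")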
val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION =
@@ -3271,7 +3399,7 @@ object SQLConf {
.doc("(Deprecated since Spark 3.0, please set 'spark.sql.execution.arrow.pyspark.enabled'.)")
.version("2.3.0")
.booleanConf
- .createWithDefault(true)
+ .createWithDefault(false)
val ARROW_PYSPARK_EXECUTION_ENABLED =
buildConf("spark.sql.execution.arrow.pyspark.enabled")
@@ -3404,6 +3532,19 @@ object SQLConf {
.intConf
.createWithDefault(10000)
+ val ARROW_EXECUTION_MAX_RECORDS_PER_OUTPUT_BATCH =
+ buildConf("spark.sql.execution.arrow.maxRecordsPerOutputBatch")
+ .doc("When using Apache Arrow, limit the maximum number of records that can be output " +
+ "in a single ArrowRecordBatch to the downstream operator. If set to zero or negative " +
+ "there is no limit. Note that the complete ArrowRecordBatch is actually created but " +
+ "the number of records is limited when sending it to the downstream operator. This is " +
+ "used to avoid large batches being sent to the downstream operator including " +
+ "the columnar-based operator implemented by third-party libraries.")
+ .version("4.1.0")
+ .internal()
+ .intConf
+ .createWithDefault(-1)
+
val ARROW_EXECUTION_MAX_BYTES_PER_BATCH =
buildConf("spark.sql.execution.arrow.maxBytesPerBatch")
.internal()
@@ -3423,11 +3564,15 @@ object SQLConf {
"than zero and less than INT_MAX.")
.createWithDefaultString("256MB")
- val ARROW_TRANSFORM_WITH_STATE_IN_PANDAS_MAX_RECORDS_PER_BATCH =
- buildConf("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch")
- .doc("When using TransformWithStateInPandas, limit the maximum number of state records " +
- "that can be written to a single ArrowRecordBatch in memory.")
+ val ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH =
+ buildConf("spark.sql.execution.arrow.transformWithStateInPySpark.maxStateRecordsPerBatch")
+ .doc("When using TransformWithState in PySpark (both Python Row and Pandas), limit " +
+ "the maximum number of state records that can be written to a single ArrowRecordBatch " +
+ "in memory.")
.version("4.0.0")
+ // NOTE: The old config name was already released in Spark 4.0.0, so support for it
+ // should not be removed.
+ .withAlternative("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch")
.intConf
.createWithDefault(10000)
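Because of `withAlternative`, a job that still sets the 4.0.0 key keeps working; the renamed entry is expected to fall back to the old key when the new one is unset (illustrative):

    spark.conf.set(
      "spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch", "20000")
    // SQLConf.get.arrowTransformWithStateInPySparkMaxStateRecordsPerBatch should then
    // return 20000, since the renamed conf reads the old key as an alternative.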
@@ -3506,13 +3651,22 @@ object SQLConf {
// show full stacktrace in tests but hide in production by default.
.createWithDefault(!Utils.isTesting)
+ val PYSPARK_ARROW_VALIDATE_SCHEMA =
+ buildConf("spark.sql.execution.arrow.pyspark.validateSchema.enabled")
+ .doc(
+ "When true, validate the schema of Arrow batches returned by mapInArrow, mapInPandas " +
+ "and DataSource against the expected schema to ensure that they are compatible.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
+
val PYTHON_UDF_ARROW_ENABLED =
buildConf("spark.sql.execution.pythonUDF.arrow.enabled")
.doc("Enable Arrow optimization in regular Python UDFs. This optimization " +
"can only be enabled when the given function takes at least one argument.")
.version("3.4.0")
.booleanConf
- .createWithDefault(true)
+ .createWithDefault(false)
val PYTHON_UDF_ARROW_CONCURRENCY_LEVEL =
buildConf("spark.sql.execution.pythonUDF.arrow.concurrency.level")
@@ -3525,6 +3679,15 @@ object SQLConf {
" must be more than one.")
.createOptional
+ val PYTHON_UDF_ARROW_FALLBACK_ON_UDT =
+ buildConf("spark.sql.execution.pythonUDF.arrow.legacy.fallbackOnUDT")
+ .internal()
+ .doc("When true, Arrow-optimized Python UDF will fallback to the regular UDF when " +
+ "its input or output is UDT.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
val PYTHON_TABLE_UDF_ARROW_ENABLED =
buildConf("spark.sql.execution.pythonUDTF.arrow.enabled")
.doc("Enable Arrow optimization for Python UDTFs.")
@@ -3766,10 +3929,8 @@ object SQLConf {
"dataframe.write.option(\"partitionOverwriteMode\", \"dynamic\").save(path)."
)
.version("2.3.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(PartitionOverwriteMode.values.map(_.toString))
- .createWithDefault(PartitionOverwriteMode.STATIC.toString)
+ .enumConf(PartitionOverwriteMode)
+ .createWithDefault(PartitionOverwriteMode.STATIC)
object StoreAssignmentPolicy extends Enumeration {
val ANSI, LEGACY, STRICT = Value
@@ -3791,10 +3952,8 @@ object SQLConf {
"not allowed."
)
.version("3.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(StoreAssignmentPolicy.values.map(_.toString))
- .createWithDefault(StoreAssignmentPolicy.ANSI.toString)
+ .enumConf(StoreAssignmentPolicy)
+ .createWithDefault(StoreAssignmentPolicy.ANSI)
val ANSI_ENABLED = buildConf(SqlApiConfHelper.ANSI_ENABLED_KEY)
.doc("When true, Spark SQL uses an ANSI compliant dialect instead of being Hive compliant. " +
@@ -4230,8 +4389,8 @@ object SQLConf {
"The default value is -1 which corresponds to 6 level in the current implementation.")
.version("2.4.0")
.intConf
- .checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION)
- .createOptional
+ .checkValues((1 to 9).toSet + CodecFactory.DEFAULT_DEFLATE_LEVEL)
+ .createWithDefault(CodecFactory.DEFAULT_DEFLATE_LEVEL)
val AVRO_XZ_LEVEL = buildConf("spark.sql.avro.xz.level")
.doc("Compression level for the xz codec used in writing of AVRO files. " +
@@ -4240,14 +4399,13 @@ object SQLConf {
.version("4.0.0")
.intConf
.checkValue(v => v > 0 && v <= 9, "The value must be in the range of from 1 to 9 inclusive.")
- .createOptional
+ .createWithDefault(CodecFactory.DEFAULT_XZ_LEVEL)
val AVRO_ZSTANDARD_LEVEL = buildConf("spark.sql.avro.zstandard.level")
- .doc("Compression level for the zstandard codec used in writing of AVRO files. " +
- "The default value is 3.")
+ .doc("Compression level for the zstandard codec used in writing of AVRO files. ")
.version("4.0.0")
.intConf
- .createOptional
+ .createWithDefault(CodecFactory.DEFAULT_ZSTANDARD_LEVEL)
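The defaults now come from Avro's `CodecFactory` constants instead of being left unset. As an assumption (values per recent Avro releases, worth verifying), these are -1 for deflate (zlib's own default), 6 for xz, and 3 for zstandard:

    import org.apache.avro.file.CodecFactory

    // Assumed values, printed here only as a sanity check.
    println(CodecFactory.DEFAULT_DEFLATE_LEVEL)    // -1, i.e. Deflater.DEFAULT_COMPRESSION
    println(CodecFactory.DEFAULT_XZ_LEVEL)         // 6
    println(CodecFactory.DEFAULT_ZSTANDARD_LEVEL)  // 3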
val AVRO_ZSTANDARD_BUFFER_POOL_ENABLED = buildConf("spark.sql.avro.zstandard.bufferPool.enabled")
.doc("If true, enable buffer pool of ZSTD JNI library when writing of AVRO files")
@@ -4448,10 +4606,8 @@ object SQLConf {
"Before the 3.4.0 release, Spark only supports the TIMESTAMP WITH " +
"LOCAL TIME ZONE type.")
.version("3.4.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(TimestampTypes.values.map(_.toString))
- .createWithDefault(TimestampTypes.TIMESTAMP_LTZ.toString)
+ .enumConf(TimestampTypes)
+ .createWithDefault(TimestampTypes.TIMESTAMP_LTZ)
val DATETIME_JAVA8API_ENABLED = buildConf("spark.sql.datetime.java8API.enabled")
.doc("If the configuration property is set to true, java.time.Instant and " +
@@ -4526,10 +4682,8 @@ object SQLConf {
"fails if duplicated map keys are detected. When LAST_WIN, the map key that is inserted " +
"at last takes precedence.")
.version("3.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(MapKeyDedupPolicy.values.map(_.toString))
- .createWithDefault(MapKeyDedupPolicy.EXCEPTION.toString)
+ .enumConf(MapKeyDedupPolicy)
+ .createWithDefault(MapKeyDedupPolicy.EXCEPTION)
val LEGACY_LOOSE_UPCAST = buildConf("spark.sql.legacy.doLooseUpcast")
.internal()
@@ -4545,10 +4699,24 @@ object SQLConf {
"The default is CORRECTED, inner CTE definitions take precedence. This config " +
"will be removed in future versions and CORRECTED will be the only behavior.")
.version("3.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
- .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString)
+ .enumConf(LegacyBehaviorPolicy)
+ .createWithDefault(LegacyBehaviorPolicy.CORRECTED)
+
+ val CTE_RECURSION_LEVEL_LIMIT = buildConf("spark.sql.cteRecursionLevelLimit")
+ .doc("Maximum level of recursion that is allowed while executing a recursive CTE definition." +
+ "If a query does not get exhausted before reaching this limit it fails. Use -1 for " +
+ "unlimited.")
+ .version("4.1.0")
+ .intConf
+ .createWithDefault(100)
+
+ val CTE_RECURSION_ROW_LIMIT = buildConf("spark.sql.cteRecursionRowLimit")
+ .doc("Maximum number of rows that can be returned when executing a recursive CTE definition." +
+ "If a query does not get exhausted before reaching this limit it fails. Use -1 for " +
+ "unlimited.")
+ .version("4.1.0")
+ .intConf
+ .createWithDefault(1000000)
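A sketch of the failure mode these two limits guard against, assuming the `WITH RECURSIVE` syntax from the recursive CTE work:

    // A recursion with no terminating predicate fails once the depth exceeds
    // spark.sql.cteRecursionLevelLimit (or the row count exceeds spark.sql.cteRecursionRowLimit).
    spark.sql("SET spark.sql.cteRecursionLevelLimit = 10")
    spark.sql(
      "WITH RECURSIVE r(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM r) SELECT * FROM r").show()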
val LEGACY_INLINE_CTE_IN_COMMANDS = buildConf("spark.sql.legacy.inlineCTEInCommands")
.internal()
@@ -4567,10 +4735,8 @@ object SQLConf {
"When set to EXCEPTION, RuntimeException is thrown when we will get different " +
"results. The default is CORRECTED.")
.version("3.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
- .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString)
+ .enumConf(LegacyBehaviorPolicy)
+ .createWithDefault(LegacyBehaviorPolicy.CORRECTED)
val LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC =
buildConf("spark.sql.legacy.followThreeValuedLogicInArrayExists")
@@ -4694,6 +4860,13 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val PYTHON_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.python.filterPushdown.enabled")
+ .doc("When true, enable filter pushdown to Python datasource, at the cost of running " +
+ "Python worker one additional time during planning.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
val CSV_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.csv.filterPushdown.enabled")
.doc("When true, enable filter pushdown to CSV datasource.")
.version("3.0.0")
@@ -4851,10 +5024,8 @@ object SQLConf {
"When EXCEPTION, Spark will fail the writing if it sees ancient " +
"timestamps that are ambiguous between the two calendars.")
.version("3.1.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
- .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString)
+ .enumConf(LegacyBehaviorPolicy)
+ .createWithDefault(LegacyBehaviorPolicy.CORRECTED)
val PARQUET_REBASE_MODE_IN_WRITE =
buildConf("spark.sql.parquet.datetimeRebaseModeInWrite")
@@ -4868,10 +5039,8 @@ object SQLConf {
"TIMESTAMP_MILLIS, TIMESTAMP_MICROS. The INT96 type has the separate config: " +
s"${PARQUET_INT96_REBASE_MODE_IN_WRITE.key}.")
.version("3.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
- .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString)
+ .enumConf(LegacyBehaviorPolicy)
+ .createWithDefault(LegacyBehaviorPolicy.CORRECTED)
val PARQUET_INT96_REBASE_MODE_IN_READ =
buildConf("spark.sql.parquet.int96RebaseModeInRead")
@@ -4883,10 +5052,8 @@ object SQLConf {
"timestamps that are ambiguous between the two calendars. This config is only effective " +
"if the writer info (like Spark, Hive) of the Parquet files is unknown.")
.version("3.1.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
- .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString)
+ .enumConf(LegacyBehaviorPolicy)
+ .createWithDefault(LegacyBehaviorPolicy.CORRECTED)
val PARQUET_REBASE_MODE_IN_READ =
buildConf("spark.sql.parquet.datetimeRebaseModeInRead")
@@ -4902,10 +5069,8 @@ object SQLConf {
s"${PARQUET_INT96_REBASE_MODE_IN_READ.key}.")
.version("3.0.0")
.withAlternative("spark.sql.legacy.parquet.datetimeRebaseModeInRead")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
- .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString)
+ .enumConf(LegacyBehaviorPolicy)
+ .createWithDefault(LegacyBehaviorPolicy.CORRECTED)
val AVRO_REBASE_MODE_IN_WRITE =
buildConf("spark.sql.avro.datetimeRebaseModeInWrite")
@@ -4916,10 +5081,8 @@ object SQLConf {
"When EXCEPTION, Spark will fail the writing if it sees " +
"ancient dates/timestamps that are ambiguous between the two calendars.")
.version("3.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
- .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString)
+ .enumConf(LegacyBehaviorPolicy)
+ .createWithDefault(LegacyBehaviorPolicy.CORRECTED)
val AVRO_REBASE_MODE_IN_READ =
buildConf("spark.sql.avro.datetimeRebaseModeInRead")
@@ -4931,10 +5094,8 @@ object SQLConf {
"ancient dates/timestamps that are ambiguous between the two calendars. This config is " +
"only effective if the writer info (like Spark, Hive) of the Avro files is unknown.")
.version("3.0.0")
- .stringConf
- .transform(_.toUpperCase(Locale.ROOT))
- .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
- .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString)
+ .enumConf(LegacyBehaviorPolicy)
+ .createWithDefault(LegacyBehaviorPolicy.CORRECTED)
val SCRIPT_TRANSFORMATION_EXIT_TIMEOUT =
buildConf("spark.sql.scriptTransformation.exitTimeoutInSeconds")
@@ -5069,6 +5230,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val READ_FILE_SOURCE_TABLE_CACHE_IGNORE_OPTIONS =
+ buildConf("spark.sql.legacy.readFileSourceTableCacheIgnoreOptions")
+ .internal()
+ .doc("When set to true, reading from file source table caches the first query plan and " +
+ "ignores subsequent changes in query options. Otherwise, query options will be applied " +
+ "to the cached plan and may produce different results.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
val READ_SIDE_CHAR_PADDING = buildConf("spark.sql.readSideCharPadding")
.doc("When true, Spark applies string padding when reading CHAR type columns/fields, " +
"in addition to the write-side padding. This config is true by default to better enforce " +
@@ -5265,9 +5436,8 @@ object SQLConf {
"STANDARD includes an additional JSON field `message`. This configuration property " +
"influences on error messages of Thrift Server and SQL CLI while running queries.")
.version("3.4.0")
- .stringConf.transform(_.toUpperCase(Locale.ROOT))
- .checkValues(ErrorMessageFormat.values.map(_.toString))
- .createWithDefault(ErrorMessageFormat.PRETTY.toString)
+ .enumConf(ErrorMessageFormat)
+ .createWithDefault(ErrorMessageFormat.PRETTY)
val LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED =
buildConf("spark.sql.lateralColumnAlias.enableImplicitResolution")
@@ -5587,6 +5757,29 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val VARIABLES_UNDER_IDENTIFIER_IN_VIEW =
+ buildConf("spark.sql.legacy.allowSessionVariableInPersistedView")
+ .internal()
+ .doc(
+ "When set to true, variables can be found under identifiers in a view query. Throw " +
+ "otherwise."
+ )
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val RUN_COLLATION_TYPE_CASTS_BEFORE_ALIAS_ASSIGNMENT =
+ buildConf("spark.sql.runCollationTypeCastsBeforeAliasAssignment.enabled")
+ .internal()
+ .doc(
+ "When set to true, rules like ResolveAliases or ResolveAggregateFunctions will run " +
+ "CollationTypeCasts before alias assignment. This is necessary for correct alias " +
+ "generation."
+ )
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
+
/**
* Holds information about keys that have been deprecated.
*
@@ -5749,13 +5942,13 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def optimizerInSetSwitchThreshold: Int = getConf(OPTIMIZER_INSET_SWITCH_THRESHOLD)
- def planChangeLogLevel: String = getConf(PLAN_CHANGE_LOG_LEVEL)
+ def planChangeLogLevel: Level = getConf(PLAN_CHANGE_LOG_LEVEL)
def planChangeRules: Option[String] = getConf(PLAN_CHANGE_LOG_RULES)
def planChangeBatches: Option[String] = getConf(PLAN_CHANGE_LOG_BATCHES)
- def expressionTreeChangeLogLevel: String = getConf(EXPRESSION_TREE_CHANGE_LOG_LEVEL)
+ def expressionTreeChangeLogLevel: Level = getConf(EXPRESSION_TREE_CHANGE_LOG_LEVEL)
def dynamicPartitionPruningEnabled: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_ENABLED)
@@ -5785,6 +5978,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def numStateStoreInstanceMetricsToReport: Int =
getConf(STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+ def stateStoreMaintenanceShutdownTimeout: Long = getConf(STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT)
+
def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)
@@ -5792,6 +5987,21 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def stateStoreSkipNullsForStreamStreamJoins: Boolean =
getConf(STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS)
+ def stateStoreCoordinatorMultiplierForMinVersionDiffToLog: Long =
+ getConf(STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG)
+
+ def stateStoreCoordinatorMultiplierForMinTimeDiffToLog: Long =
+ getConf(STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG)
+
+ def stateStoreCoordinatorReportSnapshotUploadLag: Boolean =
+ getConf(STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG)
+
+ def stateStoreCoordinatorSnapshotLagReportInterval: Long =
+ getConf(STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL)
+
+ def stateStoreCoordinatorMaxLaggingStoresToReport: Int =
+ getConf(STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT)
+
def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)
def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED)
@@ -5895,7 +6105,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)
- def adaptiveExecutionLogLevel: String = getConf(ADAPTIVE_EXECUTION_LOG_LEVEL)
+ def adaptiveExecutionLogLevel: Level = getConf(ADAPTIVE_EXECUTION_LOG_LEVEL)
def fetchShuffleBlocksInBatch: Boolean = getConf(FETCH_SHUFFLE_BLOCKS_IN_BATCH)
@@ -5910,6 +6120,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY)
+ def stateStoreUnloadOnCommit: Boolean = getConf(STATE_STORE_UNLOAD_ON_COMMIT)
+
def streamingMaintenanceInterval: Long = getConf(STREAMING_MAINTENANCE_INTERVAL)
def stateStoreCompressionCodec: String = getConf(STATE_STORE_COMPRESSION_CODEC)
@@ -5966,7 +6178,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def filesourcePartitionFileCacheSize: Long = getConf(HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE)
def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value =
- HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE))
+ getConf(HIVE_CASE_SENSITIVE_INFERENCE)
def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT)
@@ -5980,10 +6192,12 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)
- def codegenFactoryMode: String = getConf(CODEGEN_FACTORY_MODE)
+ def codegenFactoryMode: CodegenObjectFactoryMode.Value = getConf(CODEGEN_FACTORY_MODE)
def codegenComments: Boolean = getConf(StaticSQLConf.CODEGEN_COMMENTS)
+ def codegenLogLevel: Level = getConf(CODEGEN_LOG_LEVEL)
+
def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES)
def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)
@@ -6054,9 +6268,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def legacyPostgresDatetimeMappingEnabled: Boolean =
getConf(LEGACY_POSTGRES_DATETIME_MAPPING_ENABLED)
- override def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value = {
- LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY))
- }
+ override def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value =
+ getConf(SQLConf.LEGACY_TIME_PARSER_POLICY)
def broadcastHashJoinOutputPartitioningExpandLimit: Int =
getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT)
@@ -6110,9 +6323,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def isParquetINT96TimestampConversion: Boolean = getConf(PARQUET_INT96_TIMESTAMP_CONVERSION)
- def parquetOutputTimestampType: ParquetOutputTimestampType.Value = {
- ParquetOutputTimestampType.withName(getConf(PARQUET_OUTPUT_TIMESTAMP_TYPE))
- }
+ def parquetOutputTimestampType: ParquetOutputTimestampType.Value =
+ getConf(PARQUET_OUTPUT_TIMESTAMP_TYPE)
def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT)
@@ -6131,6 +6343,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
if (timeoutValue < 0) Long.MaxValue else timeoutValue
}
+ def maxBroadcastTableSizeInBytes: Long = getConf(MAX_BROADCAST_TABLE_SIZE)
+
def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME)
def convertCTAS: Boolean = getConf(CONVERT_CTAS)
@@ -6209,9 +6423,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def viewSchemaCompensation: Boolean = getConf(VIEW_SCHEMA_COMPENSATION)
def defaultCacheStorageLevel: StorageLevel =
- StorageLevel.fromString(getConf(DEFAULT_CACHE_STORAGE_LEVEL))
+ StorageLevel.fromString(getConf(DEFAULT_CACHE_STORAGE_LEVEL).name())
- def dataframeCacheLogLevel: String = getConf(DATAFRAME_CACHE_LOG_LEVEL)
+ def dataframeCacheLogLevel: Level = getConf(DATAFRAME_CACHE_LOG_LEVEL)
def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
@@ -6321,6 +6535,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def pythonUDFArrowConcurrencyLevel: Option[Int] = getConf(PYTHON_UDF_ARROW_CONCURRENCY_LEVEL)
+ def pythonUDFArrowFallbackOnUDT: Boolean = getConf(PYTHON_UDF_ARROW_FALLBACK_ON_UDT)
+
def pysparkPlotMaxRows: Int = getConf(PYSPARK_PLOT_MAX_ROWS)
def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED)
@@ -6329,10 +6545,12 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
+ def arrowMaxRecordsPerOutputBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_OUTPUT_BATCH)
+
def arrowMaxBytesPerBatch: Long = getConf(ARROW_EXECUTION_MAX_BYTES_PER_BATCH)
- def arrowTransformWithStateInPandasMaxRecordsPerBatch: Int =
- getConf(ARROW_TRANSFORM_WITH_STATE_IN_PANDAS_MAX_RECORDS_PER_BATCH)
+ def arrowTransformWithStateInPySparkMaxStateRecordsPerBatch: Int =
+ getConf(ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH)
def arrowUseLargeVarTypes: Boolean = getConf(ARROW_EXECUTION_USE_LARGE_VAR_TYPES)
@@ -6344,6 +6562,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)
+ def pysparkArrowValidateSchema: Boolean = getConf(PYSPARK_ARROW_VALIDATE_SCHEMA)
+
def pandasGroupedMapAssignColumnsByName: Boolean =
getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME)
@@ -6382,10 +6602,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def validatePartitionColumns: Boolean = getConf(VALIDATE_PARTITION_COLUMNS)
def partitionOverwriteMode: PartitionOverwriteMode.Value =
- PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE))
+ getConf(PARTITION_OVERWRITE_MODE)
def storeAssignmentPolicy: StoreAssignmentPolicy.Value =
- StoreAssignmentPolicy.withName(getConf(STORE_ASSIGNMENT_POLICY))
+ getConf(STORE_ASSIGNMENT_POLICY)
override def ansiEnabled: Boolean = getConf(ANSI_ENABLED)
@@ -6408,11 +6628,11 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def chunkBase64StringEnabled: Boolean = getConf(CHUNK_BASE64_STRING_ENABLED)
def timestampType: AtomicType = getConf(TIMESTAMP_TYPE) match {
- case "TIMESTAMP_LTZ" =>
+ case TimestampTypes.TIMESTAMP_LTZ =>
// For historical reason, the TimestampType maps to TIMESTAMP WITH LOCAL TIME ZONE
TimestampType
- case "TIMESTAMP_NTZ" =>
+ case TimestampTypes.TIMESTAMP_NTZ =>
TimestampNTZType
}
@@ -6477,6 +6697,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def useListFilesFileSystemList: String = getConf(SQLConf.USE_LISTFILES_FILESYSTEM_LIST)
+ def pythonFilterPushDown: Boolean = getConf(PYTHON_FILTER_PUSHDOWN_ENABLED)
+
def csvFilterPushDown: Boolean = getConf(CSV_FILTER_PUSHDOWN_ENABLED)
def jsonFilterPushDown: Boolean = getConf(JSON_FILTER_PUSHDOWN_ENABLED)
@@ -6560,8 +6782,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def histogramNumericPropagateInputType: Boolean =
getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE)
- def errorMessageFormat: ErrorMessageFormat.Value =
- ErrorMessageFormat.withName(getConf(SQLConf.ERROR_MESSAGE_FORMAT))
+ def errorMessageFormat: ErrorMessageFormat.Value = getConf(SQLConf.ERROR_MESSAGE_FORMAT)
def defaultDatabase: String = getConf(StaticSQLConf.CATALOG_DEFAULT_DATABASE)
@@ -6593,6 +6814,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def legacyEvalCurrentTime: Boolean = getConf(SQLConf.LEGACY_EVAL_CURRENT_TIME)
+ def legacyOutputSchema: Boolean = getConf(SQLConf.LEGACY_KEEP_COMMAND_OUTPUT_SCHEMA)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ProcedureParameterImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ProcedureParameterImpl.scala
index 01ea48af1537c..ede9c0915ef24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ProcedureParameterImpl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ProcedureParameterImpl.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.internal.connector
+import org.apache.spark.sql.connector.catalog.DefaultValue
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode
import org.apache.spark.sql.types.DataType
@@ -25,5 +26,5 @@ case class ProcedureParameterImpl(
mode: Mode,
name: String,
dataType: DataType,
- defaultValueExpression: String,
+ defaultValue: DefaultValue,
comment: String) extends ProcedureParameter
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 088f0e21710d2..2f92fe3d083d3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import java.math.MathContext
import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneId}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period, ZoneId}
import java.time.temporal.ChronoUnit
import scala.collection.mutable
@@ -284,6 +284,18 @@ object RandomDataGenerator {
},
specialTs.map { s => LocalDateTime.parse(s.replace(" ", "T")) }
)
+ case _: TimeType =>
+ val specialTimes = Seq(
+ "00:00:00",
+ "23:59:59.999999"
+ )
+ randomNumeric[LocalTime](
+ rand,
+ (rand: Random) => {
+ DateTimeUtils.microsToLocalTime(rand.between(0, 24 * 60 * 60 * 1000 * 1000L))
+ },
+ specialTimes.map(LocalTime.parse)
+ )
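The bound used above is the number of microseconds in a day; since the upper bound of `rand.between` is exclusive, the largest generated value maps to 23:59:59.999999:

    // 24 * 60 * 60 * 1000 * 1000L == 86400000000L microseconds per day,
    // so generated values fall in [0, 86399999999], i.e. 00:00:00 .. 23:59:59.999999.
    assert(24 * 60 * 60 * 1000 * 1000L == 86400000000L)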
case CalendarIntervalType => Some(() => {
val months = rand.nextInt(1000)
val days = rand.nextInt(10000)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 7572843f44a19..3457a9ced4e31 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period}
import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException}
import org.apache.spark.sql.Row
@@ -415,4 +415,33 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
}
}
}
+
+ test("converting java.time.LocalTime to TimeType") {
+ Seq(
+ "00:00:00",
+ "01:02:03.999",
+ "02:59:01",
+ "12:30:02.0",
+ "22:00:00.000001",
+ "23:59:59.999999").foreach { time =>
+ val input = LocalTime.parse(time)
+ val result = CatalystTypeConverters.convertToCatalyst(input)
+ val expected = DateTimeUtils.localTimeToMicros(input)
+ assert(result === expected)
+ }
+ }
+
+ test("converting TimeType to java.time.LocalTime") {
+ Seq(
+ 0,
+ 1,
+ 59000000,
+ 3600000001L,
+ 43200999999L,
+ 86399000000L,
+ 86399999999L).foreach { us =>
+ val localTime = DateTimeUtils.microsToLocalTime(us)
+ assert(CatalystTypeConverters.createToScalaConverter(TimeType())(us) === localTime)
+ }
+ }
}
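Taken together, the two tests above imply a round trip between microsecond-of-day values and `LocalTime`; a property-style restatement (illustrative, not part of the suite):

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    // For any microsecond-of-day value in [0, 86399999999]:
    //   DateTimeUtils.localTimeToMicros(DateTimeUtils.microsToLocalTime(us)) == us
    Seq(0L, 1L, 43200999999L, 86399999999L).foreach { us =>
      assert(DateTimeUtils.localTimeToMicros(DateTimeUtils.microsToLocalTime(us)) == us)
    }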
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index eab4ddc666be4..f0dabfd976a7c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -819,6 +819,25 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}
}
+ test("CURRENT_TIME should be case insensitive") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ val input = Project(Seq(
+ // The user references "current_time" or "CURRENT_TIME" in the query
+ UnresolvedAttribute("current_time"),
+ UnresolvedAttribute("CURRENT_TIME")
+ ), testRelation)
+
+ // The analyzer should resolve both to the same expression: CurrentTime()
+ val expected = Project(Seq(
+ Alias(CurrentTime(), toPrettySQL(CurrentTime()))(),
+ Alias(CurrentTime(), toPrettySQL(CurrentTime()))()
+ ), testRelation).analyze
+
+ checkAnalysis(input, expected)
+ }
+ }
+
+
test("CTE with non-existing column alias") {
assertAnalysisErrorCondition(parsePlan("WITH t(x) AS (SELECT 1) SELECT * FROM t WHERE y = 1"),
"UNRESOLVED_COLUMN.WITH_SUGGESTION",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
index 133670d5fcced..0afdffb8b5e7c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
@@ -30,7 +30,8 @@ import org.apache.spark.util.ArrayImplicits._
class CreateTablePartitioningValidationSuite extends AnalysisTest {
val tableSpec =
- UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, None, false)
+ UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, None, false,
+ Seq.empty)
test("CreateTableAsSelect: fail missing top-level column") {
val plan = CreateTableAsSelect(
UnresolvedIdentifier(Array("table_name").toImmutableArraySeq),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/GroupByOrdinalsRepeatedAnalysisSuite.scala
similarity index 59%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/GroupByOrdinalsRepeatedAnalysisSuite.scala
index 39cf298aec434..ac120e80d51c1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/GroupByOrdinalsRepeatedAnalysisSuite.scala
@@ -17,63 +17,42 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.analysis.TestRelations.{testRelation, testRelation2}
+import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.internal.SQLConf
-class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest {
- private lazy val a = testRelation2.output(0)
- private lazy val b = testRelation2.output(1)
+class GroupByOrdinalsRepeatedAnalysisSuite extends AnalysisTest {
test("unresolved ordinal should not be unresolved") {
// Expression OrderByOrdinal is unresolved.
assert(!UnresolvedOrdinal(0).resolved)
}
- test("order by ordinal") {
- // Tests order by ordinal, apply single rule.
- val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc)
+ test("SPARK-45920: group by ordinal repeated analysis") {
+ val plan = testRelation.groupBy(Literal(1))(Literal(100).as("a")).analyze
comparePlans(
- SubstituteUnresolvedOrdinals.apply(plan),
- testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc))
-
- // Tests order by ordinal, do full analysis
- checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc))
+ plan,
+ testRelation.groupBy(Literal(1))(Literal(100).as("a")).analyze
+ )
- // order by ordinal can be turned off by config
- withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "false") {
- comparePlans(
- SubstituteUnresolvedOrdinals.apply(plan),
- testRelation2.orderBy(Literal(1).asc, Literal(2).asc))
+ val testRelationWithData = testRelation.copy(data = Seq(new GenericInternalRow(Array(1: Any))))
+ // Copy the plan to reset its `analyzed` flag, so that analyzer rules will re-apply.
+ val copiedPlan = plan.transform {
+ case _: LocalRelation => testRelationWithData
}
- }
-
- test("group by ordinal") {
- // Tests group by ordinal, apply single rule.
- val plan2 = testRelation2.groupBy(Literal(1), Literal(2))($"a", $"b")
comparePlans(
- SubstituteUnresolvedOrdinals.apply(plan2),
- testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))($"a", $"b"))
-
- // Tests group by ordinal, do full analysis
- checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b))
-
- // group by ordinal can be turned off by config
- withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") {
- comparePlans(
- SubstituteUnresolvedOrdinals.apply(plan2),
- testRelation2.groupBy(Literal(1), Literal(2))($"a", $"b"))
- }
+ copiedPlan.analyze, // repeated analysis
+ testRelationWithData.groupBy(Literal(1))(Literal(100).as("a")).analyze
+ )
}
- test("SPARK-45920: group by ordinal repeated analysis") {
- val plan = testRelation.groupBy(Literal(1))(Literal(100).as("a")).analyze
+ test("SPARK-47895: group by all repeated analysis") {
+ val plan = testRelation.groupBy($"all")(Literal(100).as("a")).analyze
comparePlans(
plan,
- testRelation.groupBy(Literal(1))(Literal(100).as("a"))
+ testRelation.groupBy(Literal(1))(Literal(100).as("a")).analyze
)
val testRelationWithData = testRelation.copy(data = Seq(new GenericInternalRow(Array(1: Any))))
@@ -83,15 +62,15 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest {
}
comparePlans(
copiedPlan.analyze, // repeated analysis
- testRelationWithData.groupBy(Literal(1))(Literal(100).as("a"))
+ testRelationWithData.groupBy(Literal(1))(Literal(100).as("a")).analyze
)
}
- test("SPARK-47895: group by all repeated analysis") {
- val plan = testRelation.groupBy($"all")(Literal(100).as("a")).analyze
+ test("SPARK-47895: group by alias repeated analysis") {
+ val plan = testRelation.groupBy($"b")(Literal(100).as("b")).analyze
comparePlans(
plan,
- testRelation.groupBy(Literal(1))(Literal(100).as("a"))
+ testRelation.groupBy(Literal(1))(Literal(100).as("b")).analyze
)
val testRelationWithData = testRelation.copy(data = Seq(new GenericInternalRow(Array(1: Any))))
@@ -101,7 +80,7 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest {
}
comparePlans(
copiedPlan.analyze, // repeated analysis
- testRelationWithData.groupBy(Literal(1))(Literal(100).as("a"))
+ testRelationWithData.groupBy(Literal(1))(Literal(100).as("b")).analyze
)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
index f231164d5c25a..662c740e2f96c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
@@ -21,11 +21,11 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.EvaluateUnresolvedInlineTable
-import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CurrentTimestamp, Literal, Rand}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CurrentTime, CurrentTimestamp, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.optimizer.{ComputeCurrentTime, EvalInlineTables}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{LongType, NullType, TimestampType}
+import org.apache.spark.sql.types.{LongType, NullType, TimestampType, TimeType}
/**
* Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in
@@ -113,6 +113,32 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
}
}
+ test("cast and execute CURRENT_TIME expressions") {
+ val table = UnresolvedInlineTable(
+ Seq("c1"),
+ Seq(
+ Seq(CurrentTime()),
+ Seq(CurrentTime())
+ )
+ )
+ val resolved = ResolveInlineTables(table)
+ assert(resolved.isInstanceOf[ResolvedInlineTable],
+ "Expected an inline table to be resolved into a ResolvedInlineTable")
+
+ val transformed = ComputeCurrentTime(resolved)
+ EvalInlineTables(transformed) match {
+ case LocalRelation(output, data, _, _) =>
+ // expect default precision = 6
+ assert(output.map(_.dataType) == Seq(TimeType(6)))
+ // Should have 2 rows
+ assert(data.size == 2)
+ // Both rows should have the *same* microsecond value for current_time
+ assert(data(0).getLong(0) == data(1).getLong(0),
+ "Both CURRENT_TIME calls must yield the same value in the same query")
+ }
+ }
+
+
test("convert TimeZoneAwareExpression") {
val table = UnresolvedInlineTable(Seq("c1"),
Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala
index 29c6c63ecfeab..0b872d61eca3e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala
@@ -333,6 +333,37 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
hasTransform = true)
}
+ test("SPARK-48922: Avoid redundant array transform of identical expression for map type") {
+ def assertMapField(fromType: MapType, toType: MapType, transformNum: Int): Unit = {
+ val table = TestRelation(Seq($"a".int, Symbol("map").map(toType)))
+ val query = TestRelation(Seq(Symbol("map").map(fromType), $"a".int))
+
+ val writePlan = byName(table, query).analyze
+
+ assertResolved(writePlan)
+ checkAnalysis(writePlan, writePlan)
+
+ val transforms = writePlan.children.head.expressions.flatMap { e =>
+ e.flatMap {
+ case t: ArrayTransform => Some(t)
+ case _ => None
+ }
+ }
+ assert(transforms.size == transformNum)
+ }
+
+ assertMapField(MapType(LongType, StringType), MapType(LongType, StringType), 0)
+ assertMapField(
+ MapType(LongType, new StructType().add("x", "int").add("y", "int")),
+ MapType(LongType, new StructType().add("y", "int").add("x", "byte")),
+ 1)
+ assertMapField(MapType(LongType, LongType), MapType(IntegerType, LongType), 1)
+ assertMapField(
+ MapType(LongType, new StructType().add("x", "int").add("y", "int")),
+ MapType(IntegerType, new StructType().add("y", "int").add("x", "byte")),
+ 2)
+ }
+
test("SPARK-33136: output resolved on complex types for V2 write commands") {
def assertTypeCompatibility(name: String, fromType: DataType, toType: DataType): Unit = {
val table = TestRelation(StructType(Seq(StructField("a", toType))))
@@ -420,12 +451,14 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
val parsedPlan = byName(table, query)
- assertNotResolved(parsedPlan)
- assertAnalysisErrorCondition(
- parsedPlan,
- expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
- )
+ withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") {
+ assertNotResolved(parsedPlan)
+ assertAnalysisErrorCondition(
+ parsedPlan,
+ expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+ expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
+ )
+ }
}
test("byName: case sensitive column resolution") {
@@ -435,12 +468,14 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
val parsedPlan = byName(table, query)
- assertNotResolved(parsedPlan)
- assertAnalysisErrorCondition(
- parsedPlan,
- expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
- )
+ withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") {
+ assertNotResolved(parsedPlan)
+ assertAnalysisErrorCondition(
+ parsedPlan,
+ expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+ expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
+ )
+ }
}
test("byName: case insensitive column resolution") {
@@ -513,12 +548,14 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
val parsedPlan = byName(table, query)
- assertNotResolved(parsedPlan)
- assertAnalysisErrorCondition(
- parsedPlan,
- expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
- )
+ withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") {
+ assertNotResolved(parsedPlan)
+ assertAnalysisErrorCondition(
+ parsedPlan,
+ expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+ expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
+ )
+ }
}
test("byName: insert safe cast") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitLikeExpressionValidatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitLikeExpressionValidatorSuite.scala
new file mode 100644
index 0000000000000..de4ba6d66adae
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitLikeExpressionValidatorSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LocalLimit, Offset, Tail}
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.types.IntegerType
+
+class LimitLikeExpressionValidatorSuite extends SparkFunSuite with QueryErrorsBase {
+
+ private val limitLikeExpressionValidator = new LimitLikeExpressionValidator
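+ // Each case: (builder that wraps an expression in the limit-like plan node,
+ // label used in the test name, name expected in the error parameters).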
+ private val testCases = Seq(
+ (LocalLimit(_, null), "localLimit", "limit"),
+ (GlobalLimit(_, null), "globalLimit", "limit"),
+ (Offset(_, null), "offset", "offset"),
+ (Tail(_, null), "tail", "tail")
+ )
+
+ for ((planBuilder, name, simpleName) <- testCases) {
+ test(s"Basic $name without errors") {
+ val expr = Literal(42, IntegerType)
+ val plan = planBuilder(expr)
+ assert(limitLikeExpressionValidator.validateLimitLikeExpr(expr, plan) == expr)
+ }
+
+ test(s"Unfoldable $name") {
+ val col = AttributeReference(name = "foo", dataType = IntegerType)()
+ val plan = planBuilder(col)
+ checkError(
+ exception = intercept[AnalysisException] {
+ limitLikeExpressionValidator.validateLimitLikeExpr(col, plan)
+ },
+ condition = "INVALID_LIMIT_LIKE_EXPRESSION.IS_UNFOLDABLE",
+ parameters = Map("name" -> simpleName, "expr" -> toSQLExpr(col))
+ )
+ }
+
+ test(s"$name with non-integer") {
+ val anyNonInteger = Literal("42")
+ val plan = planBuilder(anyNonInteger)
+ checkError(
+ exception = intercept[AnalysisException] {
+ limitLikeExpressionValidator.validateLimitLikeExpr(anyNonInteger, plan)
+ },
+ condition = "INVALID_LIMIT_LIKE_EXPRESSION.DATA_TYPE",
+ parameters = Map(
+ "name" -> simpleName,
+ "expr" -> toSQLExpr(anyNonInteger),
+ "dataType" -> toSQLType(anyNonInteger.dataType)
+ )
+ )
+ }
+
+ test(s"$name with null") {
+ val expr = Cast(Literal(null), IntegerType)
+ val plan = planBuilder(expr)
+ checkError(
+ exception = intercept[AnalysisException] {
+ limitLikeExpressionValidator.validateLimitLikeExpr(expr, plan)
+ },
+ condition = "INVALID_LIMIT_LIKE_EXPRESSION.IS_NULL",
+ parameters = Map("name" -> simpleName, "expr" -> toSQLExpr(expr))
+ )
+ }
+
+ test(s"$name with negative integer") {
+ val expr = Literal(-1, IntegerType)
+ val plan = planBuilder(expr)
+ checkError(
+ exception = intercept[AnalysisException] {
+ limitLikeExpressionValidator.validateLimitLikeExpr(expr, plan)
+ },
+ condition = "INVALID_LIMIT_LIKE_EXPRESSION.IS_NEGATIVE",
+ parameters =
+ Map("name" -> simpleName, "expr" -> toSQLExpr(expr), "v" -> toSQLValue(-1, IntegerType))
+ )
+ }
+ }
+
+ test("LIMIT with OFFSET sum exceeds max int") {
+ val expr = Literal(Int.MaxValue, IntegerType)
+ val plan = LocalLimit(expr, Offset(expr, null))
+ checkError(
+ exception = intercept[AnalysisException] {
+ limitLikeExpressionValidator.validateLimitLikeExpr(expr, plan)
+ },
+ condition = "SUM_OF_LIMIT_AND_OFFSET_EXCEEDS_MAX_INT",
+ parameters = Map("limit" -> Int.MaxValue.toString, "offset" -> Int.MaxValue.toString)
+ )
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index bba784800976c..616c6d65636d7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -23,20 +23,21 @@ import java.util.Arrays
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.reflect.classTag
+import scala.reflect.{classTag, ClassTag}
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.{SPARK_DOC_ROOT, SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException}
import org.apache.spark.sql.{Encoder, Encoders, Row}
-import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData, ScroogeLikeExample}
+import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData, ScalaReflection, ScroogeLikeExample}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, EncoderField, IterableEncoder, MapEncoder, OptionEncoder, PrimitiveIntEncoder, ProductEncoder, TimestampEncoder, TransformingEncoder}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NaNvl}
import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils.{instantToMicros, microsToInstant}
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -142,6 +143,21 @@ case class OptionNestedGeneric[T](list: Option[T])
case class MapNestedGenericKey[T](list: Map[T, Int])
case class MapNestedGenericValue[T](list: Map[Int, T])
+// ADT encoding for TransformingEncoder test
+trait Base {
+ def name: String
+}
+
+case class A(name: String, number: Int) extends Base
+
+case class B(name: String, text: String) extends Base
+
+case class Struct(typ: String, name: String, number: Option[Int] = None,
+ text: Option[String] = None)
+// end ADT encoding
+
+case class V[A](v: A)
+
class Wrapper[T](val value: T) {
override def hashCode(): Int = value.hashCode()
override def equals(obj: Any): Boolean = obj match {
@@ -585,6 +601,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
encodeDecodeTest(FooEnum.E1, "scala Enum")
+ // TransformingEncoder tests ----------------------------------------------------------
private def testTransformingEncoder(
name: String,
@@ -592,7 +609,8 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
val encoder = ExpressionEncoder(TransformingEncoder(
classTag[(Long, Long)],
BinaryEncoder,
- provider))
+ provider,
+ nullable = true))
.resolveAndBind()
assert(encoder.schema == new StructType().add("value", BinaryType))
val toRow = encoder.createSerializer()
@@ -603,6 +621,32 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
testTransformingEncoder("transforming java serialization encoder", JavaSerializationCodec)
testTransformingEncoder("transforming kryo encoder", KryoSerializationCodec)
+ test("transforming encoders ADT - Frameless Injections use case") {
+ val provider = () => new Codec[Base, Struct]{
+ override def encode(in: Base): Struct = in match {
+ case A(name, number) => Struct("A", name, number = Some(number))
+ case B(name, text) => Struct("B", name, text = Some(text))
+ }
+
+ override def decode(out: Struct): Base = out match {
+ case Struct("A", name, Some(number), None) => A(name, number)
+ case Struct("B", name, None, Some(text)) => B(name, text)
+ case _ => throw new Exception(s"Invalid Base structure: $out")
+ }
+ }
+ val encoder = ExpressionEncoder(TransformingEncoder(
+ classTag[Base],
+ ScalaReflection.encoderFor[Struct],
+ provider))
+ .resolveAndBind()
+
+ val toRow = encoder.createSerializer()
+ val fromRow = encoder.createDeserializer()
+
+ assert(fromRow(toRow(A("anA", 1))) == A("anA", 1))
+ assert(fromRow(toRow(B("aB", "text"))) == B("aB", "text"))
+ }
+
test("transforming row encoder") {
val schema = new StructType().add("a", LongType).add("b", StringType)
val encoder = ExpressionEncoder(TransformingEncoder(
@@ -615,7 +659,141 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
assert(fromRow(toRow(new Wrapper(Row(9L, "x")))) == new Wrapper(Row(9L, "x")))
}
+ // The tests below relate to SPARK-49960 and TransformingEncoder usage
+ test("Encoder with OptionEncoder of transformation") {
+ type T = Option[V[V[Int]]]
+ val data: Seq[T] = Seq(None, Some(V(V(1))))
+
+ /* mimic value-class semantics: all layers except the innermost product are encoded
+ via a TransformingEncoder rather than an additional ProductEncoder */
+ val enc =
+ OptionEncoder(
+ transforming(
+ V_OF_INT,
+ true
+ )
+ )
+
+ testDataTransformingEnc(enc, data)
+ }
+ def testDataTransformingEnc[T](enc: AgnosticEncoder[T], data: Seq[T]): Unit = {
+ val encoder = ExpressionEncoder[T](enc).resolveAndBind()
+ val toRow = encoder.createSerializer()
+ val fromRow = encoder.createDeserializer()
+ // every element must survive a serialize/deserialize round trip
+ data.foreach { row =>
+ assert(fromRow(toRow(row)) === row)
+ }
+ }
+
+ def provider[A]: () => Codec[V[A], A] = () =>
+ new Codec[V[A], A]{
+ override def encode(in: V[A]): A = in.v
+ override def decode(out: A): V[A] = if (out == null) null else V(out)
+ }
+
+ def transforming[A](underlying: AgnosticEncoder[A],
+ useUnderlying: Boolean = false): TransformingEncoder[V[A], A] =
+ TransformingEncoder[V[A], A](
+ implicitly[ClassTag[V[A]]],
+ underlying,
+ provider,
+ if (useUnderlying) {
+ underlying.nullable
+ } else {
+ false
+ }
+ )
+
+ val V_INT = StructType(Seq(StructField("v", IntegerType, nullable = true)))
+
+ // product encoder for a non-nullable V
+ val V_OF_INT =
+ ProductEncoder(
+ classTag[V[Int]],
+ Seq(EncoderField("v", PrimitiveIntEncoder, nullable = false, Metadata.empty)),
+ None
+ )
+
+ test("Encoder derivation with nested TransformingEncoder of OptionEncoder") {
+ type T = V[V[Option[V[Int]]]]
+ val data: Seq[T] = Seq(V(V(None)), V(V(Some(V(1)))))
+
+ /* mimic value-class semantics: all layers except the innermost product are encoded
+ via a TransformingEncoder rather than an additional ProductEncoder */
+ val enc =
+ transforming(
+ transforming(
+ OptionEncoder(
+ V_OF_INT
+ )
+ )
+ )
+
+ testDataTransformingEnc(enc, data)
+ }
+
+ test("Encoder derivation with TransformingEncoder of OptionEncoder") {
+ type T = V[Option[V[Int]]]
+ val data: Seq[T] = Seq(V(None), V(Some(V(1))))
+
+ /* mimic value-class semantics: all layers except the innermost product are encoded
+ via a TransformingEncoder rather than an additional ProductEncoder */
+ val enc =
+ transforming(
+ OptionEncoder(
+ V_OF_INT
+ )
+ )
+
+ testDataTransformingEnc(enc, data)
+ }
+
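+ // Encodes V[Long] (microseconds since the epoch) as a TIMESTAMP column by converting
+ // through java.sql.Timestamp in the codec; reused below as element, key and value encoder.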
+ val longEncForTimestamp: AgnosticEncoder[V[Long]] =
+ TransformingEncoder[V[Long], java.sql.Timestamp](
+ classTag,
+ TimestampEncoder(true),
+ () =>
+ new Codec[V[Long], java.sql.Timestamp] with Serializable {
+ override def encode(in: V[Long]): Timestamp = Timestamp.from(microsToInstant(in.v))
+
+ override def decode(out: Timestamp): V[Long] = V[Long](instantToMicros(out.toInstant))
+ }
+ )
+
+ test("TransformingEncoder as Iterable") {
+ type T = Seq[V[Long]]
+ val data: Seq[T] = Seq(Seq(V(0)), Seq(V(1), V(2)))
+
+ /* relies on validateAndSerializeElement testing for TransformingEncoder */
+ val enc: AgnosticEncoder[T] =
+ IterableEncoder[Seq[V[Long]], V[Long]](
+ implicitly[ClassTag[Seq[V[Long]]]],
+ longEncForTimestamp,
+ containsNull = false,
+ lenientSerialization = false)
+
+ assert(enc.dataType === new ArrayType(TimestampType, false))
+
+ testDataTransformingEnc(enc, data)
+ }
+
+ test("TransformingEncoder as Map Key/Value") {
+ type T = Map[V[Long], V[Long]]
+ val data: Seq[T] = Seq(Map(V(0L) -> V(0L)), Map(V(1L) -> V(1L)), Map(V(2L) -> V(2L)))
+
+ /* relies on validateAndSerializeElement testing for TransformingEncoder */
+ val enc: AgnosticEncoder[T] =
+ MapEncoder[T, V[Long], V[Long]](
+ implicitly[ClassTag[T]],
+ longEncForTimestamp,
+ longEncForTimestamp,
+ valueContainsNull = false)
+
+ assert(enc.dataType === new MapType(TimestampType, TimestampType, false))
+
+ testDataTransformingEnc(enc, data)
+ }
// Scala / Java big decimals ----------------------------------------------------------
encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 645b80ffaacb8..1609e1a4e1136 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -372,6 +372,16 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}
}
+ test("encoding/decoding TimeType to/from java.time.LocalTime") {
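+ // TIME values are stored internally as a Long holding microseconds since midnight,
+ // which is what the serialized row should contain.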
+ val schema = new StructType().add("t", TimeType())
+ val encoder = ExpressionEncoder(schema).resolveAndBind()
+ val localTime = java.time.LocalTime.parse("20:38:45.123456")
+ val row = toRow(encoder, Row(localTime))
+ assert(row.getLong(0) === DateTimeUtils.localTimeToMicros(localTime))
+ val readback = fromRow(encoder, row)
+ assert(readback.get(0).equals(localTime))
+ }
+
test("SPARK-34605: encoding/decoding DayTimeIntervalType to/from java.time.Duration") {
dayTimeIntervalTypes.foreach { dayTimeIntervalType =>
val schema = new StructType().add("d", dayTimeIntervalType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index cec49a5ae1de0..92642de94a43c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
-import java.time.{Duration, LocalDate, LocalDateTime, Period}
+import java.time.{Duration, LocalDate, LocalDateTime, LocalTime, Period}
import java.time.temporal.ChronoUnit
import java.util.{Calendar, Locale, TimeZone}
@@ -82,7 +82,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
atomicTypes.foreach(dt => checkNullCast(NullType, dt))
- atomicTypes.foreach(dt => checkNullCast(dt, StringType))
+ (atomicTypes ++ timeTypes).foreach(dt => checkNullCast(dt, StringType))
checkNullCast(StringType, BinaryType)
checkNullCast(StringType, BooleanType)
numericTypes.foreach(dt => checkNullCast(dt, BooleanType))
@@ -1457,4 +1457,31 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
}
}
+
+ test("cast time to string") {
+ Seq(
+ LocalTime.MIDNIGHT -> "00:00:00",
+ LocalTime.NOON -> "12:00:00",
+ LocalTime.of(23, 59, 59) -> "23:59:59",
+ LocalTime.of(23, 59, 59, 1000000) -> "23:59:59.001",
+ LocalTime.of(23, 59, 59, 999999000) -> "23:59:59.999999"
+ ).foreach { case (time, expectedStr) =>
+ checkEvaluation(Cast(Literal(time), StringType), expectedStr)
+ }
+
+ checkConsistencyBetweenInterpretedAndCodegen(
+ (child: Expression) => Cast(child, StringType), TimeType())
+ }
+
+ test("cast string to time") {
+ checkEvaluation(cast(Literal.create("0:0:0"), TimeType()), 0L)
+ checkEvaluation(cast(Literal.create(" 01:2:3.01 "), TimeType(2)), localTime(1, 2, 3, 10000))
+ checkEvaluation(cast(Literal.create(" 12:13:14.999"),
+ TimeType(3)), localTime(12, 13, 14, 999 * 1000))
+ checkEvaluation(cast(Literal.create("23:0:59.0001 "), TimeType(4)), localTime(23, 0, 59, 100))
+ checkEvaluation(cast(Literal.create("23:59:0.99999"),
+ TimeType(5)), localTime(23, 59, 0, 999990))
+ checkEvaluation(cast(Literal.create("23:59:59.000001 "),
+ TimeType(6)), localTime(23, 59, 59, 1))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala
index 141eaf56fffb7..534563a79742f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala
@@ -901,4 +901,10 @@ class CastWithAnsiOffSuite extends CastSuiteBase {
castOverflowErrMsg(toType))
}
}
+
+ test("cast invalid string input to time") {
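+ // "24:00:00" lies outside the valid TIME range [00:00:00, 23:59:59.999999], so with
+ // ANSI off every invalid input evaluates to null instead of throwing.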
+ Seq("a", "123", "00:00:00ABC", "24:00:00").foreach { invalidInput =>
+ checkEvaluation(cast(invalidInput, TimeType()), null)
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
index 674d306dbabb4..b62b8c1302ccf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
@@ -23,13 +23,14 @@ import java.time.DateTimeException
import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
/**
* Test suite for data type casting expression [[Cast]] with ANSI mode enabled.
@@ -182,6 +183,57 @@ class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase {
}
}
+ test("ANSI mode: disallow variant cast to non-nullable types") {
+ // Array
+ val variantVal = new VariantVal(Array[Byte](12, 3), Array[Byte](1, 0, 0))
+ val sourceArrayType = ArrayType(VariantType, containsNull = false)
+ val targetArrayType = ArrayType(StringType, containsNull = false)
+ val variantArray = Literal.create(Seq(variantVal), sourceArrayType)
+ assert(cast(variantArray, targetArrayType).checkInputDataTypes() == DataTypeMismatch(
+ errorSubClass = "CAST_WITHOUT_SUGGESTION",
+ messageParameters = Map(
+ "srcType" -> toSQLType(sourceArrayType),
+ "targetType" -> toSQLType(targetArrayType)
+ )
+ ))
+ // make sure containsNull = true works
+ val targetArrayType2 = ArrayType(StringType, containsNull = true)
+ assert(cast(variantArray, targetArrayType2).checkInputDataTypes() ==
+ TypeCheckResult.TypeCheckSuccess)
+
+ // Struct
+ val sourceStructType = StructType(Array(StructField("v", VariantType, nullable = false)))
+ val targetStructType = StructType(Array(StructField("v", StringType, nullable = false)))
+ val variantStruct = Literal.create(Row(variantVal), sourceStructType)
+ assert(cast(variantStruct, targetStructType).checkInputDataTypes() == DataTypeMismatch(
+ errorSubClass = "CAST_WITHOUT_SUGGESTION",
+ messageParameters = Map(
+ "srcType" -> toSQLType(sourceStructType),
+ "targetType" -> toSQLType(targetStructType)
+ )
+ ))
+ // make sure nullable = true works
+ val targetStructType2 = StructType(Array(StructField("v", StringType, nullable = true)))
+ assert(cast(variantStruct, targetStructType2).checkInputDataTypes() ==
+ TypeCheckResult.TypeCheckSuccess)
+
+ // Map
+ val sourceMapType = MapType(StringType, VariantType, valueContainsNull = false)
+ val targetMapType = MapType(StringType, StringType, valueContainsNull = false)
+ val variantMap = Literal.create(Map("k" -> variantVal), sourceMapType)
+ assert(cast(variantMap, targetMapType).checkInputDataTypes() == DataTypeMismatch(
+ errorSubClass = "CAST_WITHOUT_SUGGESTION",
+ messageParameters = Map(
+ "srcType" -> toSQLType(sourceMapType),
+ "targetType" -> toSQLType(targetMapType)
+ )
+ ))
+ // make sure valueContainsNull = true works
+ val targetMapType2 = MapType(StringType, StringType, valueContainsNull = true)
+ assert(cast(variantMap, targetMapType2).checkInputDataTypes() ==
+ TypeCheckResult.TypeCheckSuccess)
+ }
+
test("ANSI mode: disallow type conversions between Datatime types and Boolean types") {
val timestampLiteral = Literal(1L, TimestampType)
val checkResult1 = cast(timestampLiteral, BooleanType).checkInputDataTypes()
@@ -737,4 +789,12 @@ class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase {
val input = Literal.create(Decimal(0.000000123), DecimalType(9, 9))
checkEvaluation(cast(input, StringType), "0.000000123")
}
+
+ test("cast invalid string input to time") {
+ Seq("a", "123", "00:00:00ABC", "24:00:00").foreach { invalidInput =>
+ checkExceptionInExpression[DateTimeException](
+ cast(invalidInput, TimeType()),
+ castErrMsg(invalidInput, TimeType()))
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 4c045f9fda731..7ce14bcedf4ba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.LA
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ThreadUtils
@@ -534,6 +535,21 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
.exists(_.getMessage().getFormattedMessage.contains("Generated method too long")))
}
+ test("SPARK-51527: spark.sql.codegen.logLevel") {
+ withSQLConf(SQLConf.CODEGEN_LOG_LEVEL.key -> "INFO") {
+ val appender = new LogAppender("codegen log level")
+ withLogAppender(appender, loggerNames = Seq(classOf[CodeGenerator[_, _]].getName),
+ Some(Level.INFO)) {
+ GenerateUnsafeProjection.generate(Seq(Literal.TrueLiteral))
+ }
+ assert(appender.loggingEvents.exists { event =>
+ event.getLevel === Level.INFO &&
+ event.getMessage.getFormattedMessage.contains(
+ "public java.lang.Object generate(Object[] references)")
+ })
+ }
+ }
+
test("SPARK-28916: subexpression elimination can cause 64kb code limit on UnsafeProjection") {
val numOfExprs = 10000
val exprs = (0 to numOfExprs).flatMap(colIndex =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 019c953a3b0ac..dddc33aa43580 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.nio.charset.StandardCharsets
-import java.time.{Duration, Period, ZoneId, ZoneOffset}
+import java.time.{Duration, LocalTime, Period, ZoneId, ZoneOffset}
import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
@@ -754,6 +754,13 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkResult(Literal.create(-0F, FloatType), Literal.create(0F, FloatType))
}
+ test("Support TimeType") {
+ val time = Literal.create(LocalTime.of(23, 50, 59, 123456000), TimeType())
+ checkEvaluation(Murmur3Hash(Seq(time), 10), 258472763)
+ checkEvaluation(XxHash64(Seq(time), 10), -9197489935839400467L)
+ checkEvaluation(HiveHash(Seq(time)), -40222445)
+ }
+
private def testHash(inputSchema: StructType): Unit = {
val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
val toRow = ExpressionEncoder(inputSchema).createSerializer()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 5da5c6ac412cc..9ed0b48680c62 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.nio.charset.StandardCharsets
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffset}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period, ZoneOffset}
import java.time.temporal.ChronoUnit
import java.util.TimeZone
@@ -30,6 +30,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.localTime
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -51,6 +52,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, BinaryType), null)
checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null)
checkEvaluation(Literal.create(null, DateType), null)
+ checkEvaluation(Literal.create(null, TimeType()), null)
checkEvaluation(Literal.create(null, TimestampType), null)
checkEvaluation(Literal.create(null, CalendarIntervalType), null)
checkEvaluation(Literal.create(null, YearMonthIntervalType()), null)
@@ -81,6 +83,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.default(DateType), LocalDate.ofEpochDay(0))
checkEvaluation(Literal.default(TimestampType), Instant.ofEpochSecond(0))
}
+ checkEvaluation(Literal.default(TimeType()), LocalTime.MIDNIGHT)
checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0, 0L))
checkEvaluation(Literal.default(YearMonthIntervalType()), 0)
checkEvaluation(Literal.default(DayTimeIntervalType()), 0L)
@@ -313,6 +316,13 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
+ test("construct literals from arrays of java.time.LocalTime") {
+ val localTime0 = LocalTime.of(1, 2, 3)
+ checkEvaluation(Literal(Array(localTime0)), Array(localTime0))
+ val localTime1 = LocalTime.of(23, 59, 59, 999999000)
+ checkEvaluation(Literal(Array(localTime0, localTime1)), Array(localTime0, localTime1))
+ }
+
test("construct literals from java.time.Instant") {
Seq(
Instant.parse("0001-01-01T00:00:00Z"),
@@ -497,6 +507,11 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
}
checkEvaluation(Literal.create(duration, dt), result)
}
+
+ val time = LocalTime.of(12, 13, 14)
+ DataTypeTestUtils.timeTypes.foreach { tt =>
+ checkEvaluation(Literal.create(time, tt), localTime(12, 13, 14))
+ }
}
test("SPARK-37967: Literal.create support ObjectType") {
@@ -531,4 +546,17 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(immArraySeq), expected)
checkEvaluation(Literal.create(immArraySeq, ArrayType(DoubleType)), expected)
}
+
+ test("TimeType toString and sql") {
+ Seq(
+ Literal.default(TimeType()) -> "00:00:00",
+ Literal(LocalTime.NOON) -> "12:00:00",
+ Literal(LocalTime.of(23, 59, 59, 100 * 1000 * 1000)) -> "23:59:59.1",
+ Literal(LocalTime.of(23, 59, 59, 10000)) -> "23:59:59.00001",
+ Literal(LocalTime.of(23, 59, 59, 999999000)) -> "23:59:59.999999"
+ ).foreach { case (lit, str) =>
+ assert(lit.toString === str)
+ assert(lit.sql === s"TIME '$str'")
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
index 5c576d3de1b33..ed5843478c009 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate, Period}
+import java.time.{Duration, Instant, LocalDate, LocalTime, Period}
import java.util.concurrent.TimeUnit
import org.scalacheck.{Arbitrary, Gen}
@@ -123,6 +123,14 @@ object LiteralGenerator {
yield Literal.create(new Date(day * MILLIS_PER_DAY), DateType)
}
+ lazy val timeLiteralGen: Gen[Literal] = {
+ // Valid range for TimeType is [00:00:00, 23:59:59.999999]
+ val minTime = DateTimeUtils.localTimeToMicros(LocalTime.MIN)
+ val maxTime = DateTimeUtils.localTimeToMicros(LocalTime.MAX)
+ for { t <- Gen.choose(minTime, maxTime) }
+ yield Literal(t, TimeType())
+ }
+
private def millisGen = {
// Catalyst's Timestamp type stores number of microseconds since epoch in
// a variable of Long type. To prevent arithmetic overflow of Long on
@@ -196,6 +204,7 @@ object LiteralGenerator {
case DoubleType => doubleLiteralGen
case FloatType => floatLiteralGen
case DateType => dateLiteralGen
+ case _: TimeType => timeLiteralGen
case TimestampType => timestampLiteralGen
case TimestampNTZType => timestampNTZLiteralGen
case BooleanType => booleanLiteralGen
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeExpressionsSuite.scala
new file mode 100644
index 0000000000000..06ea49f0f71d0
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeExpressionsSuite.scala
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.time.LocalTime
+
+import org.apache.spark.{SPARK_DOC_ROOT, SparkDateTimeException, SparkFunSuite}
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLValue}
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
+import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StringType, TimeType}
+
+class TimeExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+ test("ParseToTime") {
+ checkEvaluation(new ToTime(Literal("00:00:00"), Literal.create(null)), null)
+ checkEvaluation(new ToTime(Literal("00:00:00"), NonFoldableLiteral(null, StringType)), null)
+ checkEvaluation(new ToTime(Literal(null, StringType), Literal("HH:mm:ss")), null)
+
+ checkEvaluation(new ToTime(Literal("00:00:00")), localTime())
+ checkEvaluation(new ToTime(Literal("23-59-00.000999"), Literal("HH-mm-ss.SSSSSS")),
+ localTime(23, 59, 0, 999))
+ checkEvaluation(
+ new ToTime(Literal("12.00.59.90909"), NonFoldableLiteral("HH.mm.ss.SSSSS")),
+ localTime(12, 0, 59, 909090))
+ checkEvaluation(
+ new ToTime(NonFoldableLiteral(" 12:00.909 "), Literal(" HH:mm.SSS ")),
+ localTime(12, 0, 0, 909000))
+ checkEvaluation(
+ new ToTime(
+ NonFoldableLiteral("12 hours 123 millis"),
+ NonFoldableLiteral("HH 'hours' SSS 'millis'")),
+ localTime(12, 0, 0, 123000))
+
+ checkErrorInExpression[SparkDateTimeException](
+ expression = new ToTime(Literal("100:50")),
+ condition = "CANNOT_PARSE_TIME",
+ parameters = Map("input" -> "'100:50'", "format" -> "'HH:mm:ss.SSSSSS'"))
+ checkErrorInExpression[SparkDateTimeException](
+ expression = new ToTime(Literal("100:50"), Literal("mm:HH")),
+ condition = "CANNOT_PARSE_TIME",
+ parameters = Map("input" -> "'100:50'", "format" -> "'mm:HH'"))
+ }
+
+ test("HourExpressionBuilder") {
+ // Empty expressions list
+ checkError(
+ exception = intercept[AnalysisException] {
+ HourExpressionBuilder.build("hour", Seq.empty)
+ },
+ condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
+ parameters = Map(
+ "functionName" -> "`hour`",
+ "expectedNum" -> "> 0",
+ "actualNum" -> "0",
+ "docroot" -> SPARK_DOC_ROOT)
+ )
+
+ // test TIME-typed child should build HoursOfTime
+ val timeExpr = Literal(localTime(12, 58, 59), TimeType())
+ val builtExprForTime = HourExpressionBuilder.build("hour", Seq(timeExpr))
+ assert(builtExprForTime.isInstanceOf[HoursOfTime])
+ assert(builtExprForTime.asInstanceOf[HoursOfTime].child eq timeExpr)
+
+ assert(builtExprForTime.checkInputDataTypes().isSuccess)
+
+ // test TIME-typed child should build HoursOfTime for all allowed custom precision values
+ (TimeType.MIN_PRECISION to TimeType.MICROS_PRECISION).foreach { precision =>
+ val timeExpr = Literal(localTime(12, 58, 59), TimeType(precision))
+ val builtExpr = HourExpressionBuilder.build("hour", Seq(timeExpr))
+
+ assert(builtExpr.isInstanceOf[HoursOfTime])
+ assert(builtExpr.asInstanceOf[HoursOfTime].child eq timeExpr)
+ assert(builtExpr.checkInputDataTypes().isSuccess)
+ }
+
+ // test non TIME-typed child should build Hour
+ val tsExpr = Literal("2007-09-03 10:45:23")
+ val builtExprForTs = HourExpressionBuilder.build("hour", Seq(tsExpr))
+ assert(builtExprForTs.isInstanceOf[Hour])
+ assert(builtExprForTs.asInstanceOf[Hour].child eq tsExpr)
+ }
+
+ test("Hour with TIME type") {
+ // A few test times in microseconds since midnight:
+ // time in microseconds -> expected hour
+ val testTimes = Seq(
+ localTime() -> 0,
+ localTime(1) -> 1,
+ localTime(0, 59) -> 0,
+ localTime(14, 30) -> 14,
+ localTime(12, 58, 59) -> 12,
+ localTime(23, 0, 1) -> 23,
+ localTime(23, 59, 59, 999999) -> 23
+ )
+
+ // Create a literal with TimeType() for each test microsecond value
+ // evaluate HoursOfTime(...), and check that the result matches the expected hour.
+ testTimes.foreach { case (micros, expectedHour) =>
+ checkEvaluation(
+ HoursOfTime(Literal(micros, TimeType())),
+ expectedHour)
+ }
+
+ // Verify NULL handling
+ checkEvaluation(
+ HoursOfTime(Literal.create(null, TimeType(TimeType.MICROS_PRECISION))),
+ null
+ )
+
+ checkConsistencyBetweenInterpretedAndCodegen(
+ (child: Expression) => HoursOfTime(child).replacement, TimeType())
+ }
+
+ test("MinuteExpressionBuilder") {
+ // Empty expressions list
+ checkError(
+ exception = intercept[AnalysisException] {
+ MinuteExpressionBuilder.build("minute", Seq.empty)
+ },
+ condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
+ parameters = Map(
+ "functionName" -> "`minute`",
+ "expectedNum" -> "> 0",
+ "actualNum" -> "0",
+ "docroot" -> SPARK_DOC_ROOT)
+ )
+
+ // test TIME-typed child should build MinutesOfTime for default precision value
+ val timeExpr = Literal(localTime(12, 58, 59), TimeType())
+ val builtExprForTime = MinuteExpressionBuilder.build("minute", Seq(timeExpr))
+ assert(builtExprForTime.isInstanceOf[MinutesOfTime])
+ assert(builtExprForTime.asInstanceOf[MinutesOfTime].child eq timeExpr)
+ assert(builtExprForTime.checkInputDataTypes().isSuccess)
+
+ // test TIME-typed child should build MinutesOfTime for all allowed custom precision values
+ (TimeType.MIN_PRECISION to TimeType.MICROS_PRECISION).foreach { precision =>
+ val timeExpr = Literal(localTime(12, 58, 59), TimeType(precision))
+ val builtExpr = MinuteExpressionBuilder.build("minute", Seq(timeExpr))
+
+ assert(builtExpr.isInstanceOf[MinutesOfTime])
+ assert(builtExpr.asInstanceOf[MinutesOfTime].child eq timeExpr)
+ assert(builtExpr.checkInputDataTypes().isSuccess)
+ }
+
+ // test non TIME-typed child should build Minute
+ val tsExpr = Literal("2009-07-30 12:58:59")
+ val builtExprForTs = MinuteExpressionBuilder.build("minute", Seq(tsExpr))
+ assert(builtExprForTs.isInstanceOf[Minute])
+ assert(builtExprForTs.asInstanceOf[Minute].child eq tsExpr)
+ }
+
+ test("Minute with TIME type") {
+ // A few test times in microseconds since midnight:
+ // time in microseconds -> expected minute
+ val testTimes = Seq(
+ localTime() -> 0,
+ localTime(1) -> 0,
+ localTime(0, 59) -> 59,
+ localTime(14, 30) -> 30,
+ localTime(12, 58, 59) -> 58,
+ localTime(23, 0, 1) -> 0,
+ localTime(23, 59, 59, 999999) -> 59
+ )
+
+ // Create a literal with TimeType() for each test microsecond value
+ // evaluate MinutesOfTime(...), and check that the result matches the expected minute.
+ testTimes.foreach { case (micros, expectedMinute) =>
+ checkEvaluation(
+ MinutesOfTime(Literal(micros, TimeType())),
+ expectedMinute)
+ }
+
+ // Verify NULL handling
+ checkEvaluation(
+ MinutesOfTime(Literal.create(null, TimeType(TimeType.MICROS_PRECISION))),
+ null
+ )
+
+ checkConsistencyBetweenInterpretedAndCodegen(
+ (child: Expression) => MinutesOfTime(child).replacement, TimeType())
+ }
+
+ test("creating values of TimeType via make_time") {
+ // basic case
+ checkEvaluation(
+ MakeTime(Literal(13), Literal(2), Literal(Decimal(23.5, 16, 6))),
+ LocalTime.of(13, 2, 23, 500000000))
+
+ // null cases
+ checkEvaluation(
+ MakeTime(Literal.create(null, IntegerType), Literal(18), Literal(Decimal(23.5, 16, 6))),
+ null)
+ checkEvaluation(
+ MakeTime(Literal(13), Literal.create(null, IntegerType), Literal(Decimal(23.5, 16, 6))),
+ null)
+ checkEvaluation(MakeTime(Literal(13), Literal(18), Literal.create(null, DecimalType(16, 6))),
+ null)
+
+ // Invalid cases
+ val errorCode = "DATETIME_FIELD_OUT_OF_BOUNDS.WITHOUT_SUGGESTION"
+ checkErrorInExpression[SparkDateTimeException](
+ MakeTime(Literal(25), Literal(2), Literal(Decimal(23.5, 16, 6))),
+ errorCode,
+ Map("rangeMessage" -> "Invalid value for HourOfDay (valid values 0 - 23): 25")
+ )
+ checkErrorInExpression[SparkDateTimeException](
+ MakeTime(Literal(23), Literal(-1), Literal(Decimal(23.5, 16, 6))),
+ errorCode,
+ Map("rangeMessage" -> "Invalid value for MinuteOfHour (valid values 0 - 59): -1")
+ )
+ checkErrorInExpression[SparkDateTimeException](
+ MakeTime(Literal(23), Literal(12), Literal(Decimal(100.5, 16, 6))),
+ errorCode,
+ Map("rangeMessage" -> "Invalid value for SecondOfMinute (valid values 0 - 59): 100")
+ )
+ }
+
+ test("SecondExpressionBuilder") {
+ // Empty expressions list
+ checkError(
+ exception = intercept[AnalysisException] {
+ SecondExpressionBuilder.build("second", Seq.empty)
+ },
+ condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
+ parameters = Map(
+ "functionName" -> "`second`",
+ "expectedNum" -> "> 0",
+ "actualNum" -> "0",
+ "docroot" -> SPARK_DOC_ROOT)
+ )
+
+ // test TIME-typed child should build SecondsOfTime
+ val timeExpr = Literal(localTime(12, 58, 59), TimeType())
+ val builtExprForTime = SecondExpressionBuilder.build("second", Seq(timeExpr))
+ assert(builtExprForTime.isInstanceOf[SecondsOfTime])
+ assert(builtExprForTime.asInstanceOf[SecondsOfTime].child eq timeExpr)
+
+ // test non TIME-typed child should build Second
+ val tsExpr = Literal("2007-09-03 10:45:23")
+ val builtExprForTs = SecondExpressionBuilder.build("second", Seq(tsExpr))
+ assert(builtExprForTs.isInstanceOf[Second])
+ assert(builtExprForTs.asInstanceOf[Second].child eq tsExpr)
+ }
+
+ test("Second with TIME type") {
+ // A few test times in microseconds since midnight:
+ // time in microseconds -> expected second
+ val testTimes = Seq(
+ localTime() -> 0,
+ localTime(1) -> 0,
+ localTime(0, 59) -> 0,
+ localTime(14, 30) -> 0,
+ localTime(12, 58, 59) -> 59,
+ localTime(23, 0, 1) -> 1,
+ localTime(23, 59, 59, 999999) -> 59
+ )
+
+ // Create a literal with TimeType() for each test microsecond value
+ // evaluate SecondsOfTime(...), and check that the result matches the expected second.
+ testTimes.foreach { case (micros, expectedSecond) =>
+ checkEvaluation(
+ SecondsOfTime(Literal(micros, TimeType())),
+ expectedSecond)
+ }
+
+ // Verify NULL handling
+ checkEvaluation(
+ SecondsOfTime(Literal.create(null, TimeType(TimeType.MICROS_PRECISION))),
+ null
+ )
+
+ checkConsistencyBetweenInterpretedAndCodegen(
+ (child: Expression) => SecondsOfTime(child).replacement, TimeType())
+ }
+
+ test("CurrentTime") {
+ // test valid precision
+ var expr = CurrentTime(Literal(3))
+ assert(expr.dataType == TimeType(3), "Should produce TIME(3) data type")
+ assert(expr.checkInputDataTypes() == TypeCheckSuccess)
+
+ // test default constructor => TIME(6)
+ expr = CurrentTime()
+ assert(expr.precision == 6, "Default precision should be 6")
+ assert(expr.dataType == TimeType(6))
+ assert(expr.checkInputDataTypes() == TypeCheckSuccess)
+
+ // test no value => TIME()
+ expr = CurrentTime()
+ assert(expr.precision == 6, "Default precision should be 6")
+ assert(expr.dataType == TimeType(6))
+ assert(expr.checkInputDataTypes() == TypeCheckSuccess)
+
+ // test foldable value
+ expr = CurrentTime(Literal(1 + 1))
+ assert(expr.precision == 2, "Precision should be 2")
+ assert(expr.dataType == TimeType(2))
+ assert(expr.checkInputDataTypes() == TypeCheckSuccess)
+
+ // test out of range precision => checkInputDataTypes fails
+ expr = CurrentTime(Literal(2 + 8))
+ assert(expr.checkInputDataTypes() ==
+ DataTypeMismatch(
+ errorSubClass = "VALUE_OUT_OF_RANGE",
+ messageParameters = Map(
+ "exprName" -> toSQLId("precision"),
+ "valueRange" -> s"[${TimeType.MIN_PRECISION}, ${TimeType.MICROS_PRECISION}]",
+ "currentValue" -> toSQLValue(10, IntegerType)
+ )
+ )
+ )
+
+ // test that a non-numeric value fails, since the analyzer is skipped here
+ expr = CurrentTime(Literal("2"))
+ val failure = intercept[ClassCastException] {
+ expr.precision
+ }
+ assert(failure.getMessage.contains("cannot be cast to class java.lang.Number"))
+ }
+
+ test("Second with fraction from TIME type") {
+ val time = "13:11:15.987654321"
+ assert(
+ SecondsOfTimeWithFraction(
+ Cast(Literal(time), TimeType(TimeType.MICROS_PRECISION))).resolved)
+ assert(
+ SecondsOfTimeWithFraction(
+ Cast(Literal.create(time), TimeType(TimeType.MIN_PRECISION + 3))).resolved)
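+ // The seconds value is returned as a DECIMAL that keeps as many fractional digits
+ // as the TIME type's precision allows.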
+ Seq(
+ 0 -> 15.0,
+ 1 -> 15.9,
+ 2 -> 15.98,
+ 3 -> 15.987,
+ 4 -> 15.9876,
+ 5 -> 15.98765,
+ 6 -> 15.987654).foreach { case (precision, expected) =>
+ checkEvaluation(
+ SecondsOfTimeWithFraction(Literal(localTime(13, 11, 15, 987654), TimeType(precision))),
+ BigDecimal(expected))
+ }
+ // Verify NULL handling
+ checkEvaluation(
+ SecondsOfTimeWithFraction(Literal.create(null, TimeType(TimeType.MICROS_PRECISION))),
+ null)
+ checkConsistencyBetweenInterpretedAndCodegen(
+ (child: Expression) => SecondsOfTimeWithFraction(child).replacement,
+ TimeType())
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
index 64529bf54bd22..5c297c00acc0f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
@@ -134,4 +134,11 @@ class ToPrettyStringSuite extends SparkFunSuite with ExpressionEvalHelper {
val prettyString = ToPrettyString(child)
assert(prettyString.sql === child.sql)
}
+
+ test("Time as pretty strings") {
+ checkEvaluation(ToPrettyString(Literal(1000L, TimeType())), "00:00:00.001")
+ checkEvaluation(ToPrettyString(Literal(1L, TimeType())), "00:00:00.000001")
+ checkEvaluation(ToPrettyString(Literal(
+ (23 * 3600 + 59 * 60 + 59) * 1000000L, TimeType())), "23:59:59")
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index 117436a023938..1cd8cc6228efc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -455,7 +455,7 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
|"category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],
|"price":22.99,"isbn":"0-395-19395-8"}],"bicycle":{"price":19.95,"color":"red"}},
|"email":"amy@only_for_json_udf_test.net","owner":"amy","zip code":"94025",
- |"fb:testid":"1234"}
+ |"fb:testid":"1234","":"empty string","?":"Question Mark?", " ":"Whitespace", "\t": "Tab"}
|""".stripMargin
testVariantGet(json, "$.store.bicycle", StringType, """{"color":"red","price":19.95}""")
checkEvaluation(
@@ -469,6 +469,12 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
)
testVariantGet(json, "$.store.bicycle.color", StringType, "red")
testVariantGet(json, "$.store.bicycle.price", DoubleType, 19.95)
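+ // Bracket notation can address keys that dot notation cannot: empty strings,
+ // whitespace, tabs and punctuation such as "?".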
+ testVariantGet(json, "$[\"\"]", StringType, "empty string")
+ testVariantGet(json, "$['']", StringType, "empty string")
+ testVariantGet(json, "$[\"?\"]", StringType, "Question Mark?")
+ testVariantGet(json, "$[\" \"]", StringType, "Whitespace")
+ testVariantGet(json, "$[\"\t\"]", StringType, "Tab")
+ testVariantGet(json, "$['?']", StringType, "Question Mark?")
testVariantGet(
json,
"$.store.book",
@@ -678,6 +684,9 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkInvalidPath("$1")
checkInvalidPath("$[-1]")
checkInvalidPath("""$['"]""")
+
+ checkInvalidPath("$[\"\"\"]")
+ checkInvalidPath("$[\"\\\"\"]")
}
test("cast from variant") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
index 6e1c7fc887d4e..3fa6459a93e24 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
@@ -24,12 +24,13 @@ import scala.concurrent.duration._
import scala.jdk.CollectionConverters.MapHasAsScala
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Expression, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Cast, CurrentDate, CurrentTime, CurrentTimestamp, CurrentTimeZone, Expression, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.IntegerType
import org.apache.spark.unsafe.types.UTF8String
class ComputeCurrentTimeSuite extends PlanTest {
@@ -52,6 +53,83 @@ class ComputeCurrentTimeSuite extends PlanTest {
assert(lits(0) == lits(1))
}
+ test("analyzer should replace current_time with literals") {
+ // logical plan that calls current_time() twice in the Project
+ val planInput = Project(
+ Seq(
+ Alias(CurrentTime(Literal(3)), "a")(),
+ Alias(CurrentTime(Literal(3)), "b")()
+ ),
+ LocalRelation()
+ )
+
+ val analyzed = planInput.analyze
+ val optimized = Optimize.execute(analyzed).asInstanceOf[Project]
+
+ // We expect 2 literals in the final Project. Each literal is a Long
+ // representing microseconds since midnight, truncated to precision=3.
+ val lits = literals[Long](optimized) // a helper that extracts all Literal values of type Long
+ assert(lits.size == 2, s"Expected two literal values, found ${lits.size}")
+
+ // The rule should produce the same microsecond value for both columns "a" and "b".
+ assert(lits(0) == lits(1),
+ s"Expected both current_time(3) calls to yield the same literal, " +
+ s"but got ${lits(0)} vs ${lits(1)}")
+ }
+
+ test("analyzer should replace current_time with foldable child expressions") {
+ // We build a plan that calls current_time(2 + 1) twice
+ val foldableExpr = Add(Literal(2), Literal(1)) // a foldable arithmetic expression => 3
+ val planInput = Project(
+ Seq(
+ Alias(CurrentTime(foldableExpr), "a")(),
+ Alias(CurrentTime(foldableExpr), "b")()
+ ),
+ LocalRelation()
+ )
+
+ val analyzed = planInput.analyze
+ val optimized = Optimize.execute(analyzed).asInstanceOf[Project]
+
+ // We expect the optimizer to replace current_time(2 + 1) with a literal time value,
+ // so let's extract those literal values.
+ val lits = literals[Long](optimized)
+ assert(lits.size == 2, s"Expected two literal values, found ${lits.size}")
+
+ // Both references to current_time(2 + 1) should be replaced by the same microsecond-of-day
+ assert(lits(0) == lits(1),
+ s"Expected both current_time(2 + 1) calls to yield the same literal, " +
+ s"but got ${lits(0)} vs. ${lits(1)}"
+ )
+ }
+
+ test("analyzer should replace current_time with foldable casted string-literal") {
+ // We'll build a foldable cast expression: CAST(' 0005 ' AS INT) => 5
+ val castExpr = Cast(Literal(" 0005 "), IntegerType)
+
+ // Two references to current_time(castExpr) => so we can check they're replaced consistently
+ val planInput = Project(
+ Seq(
+ Alias(CurrentTime(castExpr), "a")(),
+ Alias(CurrentTime(castExpr), "b")()
+ ),
+ LocalRelation()
+ )
+
+ val analyzed = planInput.analyze
+ val optimized = Optimize.execute(analyzed).asInstanceOf[Project]
+
+ val lits = literals[Long](optimized)
+ assert(lits.size == 2, s"Expected two literal values, found ${lits.size}")
+
+ // Both references to current_time(CAST(' 0005 ' AS INT)) in the same query
+ // should produce the same microsecond-of-day literal.
+ assert(lits(0) == lits(1),
+ s"Expected both references to yield the same literal, but got ${lits(0)} vs. ${lits(1)}"
+ )
+ }
+
test("analyzer should respect time flow in current timestamp calls") {
val in = Project(Alias(CurrentTimestamp(), "t1")() :: Nil, LocalRelation())
@@ -65,6 +143,20 @@ class ComputeCurrentTimeSuite extends PlanTest {
assert(t2 - t1 <= 1000 && t2 - t1 > 0)
}
+ test("analyzer should respect time flow in current_time calls") {
+ val in = Project(Alias(CurrentTime(Literal(4)), "t1")() :: Nil, LocalRelation())
+
+ val planT1 = Optimize.execute(in.analyze).asInstanceOf[Project]
+ sleep(5)
+ val planT2 = Optimize.execute(in.analyze).asInstanceOf[Project]
+
+ val t1 = literals[Long](planT1)(0) // the microseconds-of-day for planT1
+ val t2 = literals[Long](planT2)(0) // the microseconds-of-day for planT2
+
+ assert(t2 > t1, s"Expected a newer time in the second analysis, but got t1=$t1, t2=$t2")
+ }
+
test("analyzer should replace current_date with literals") {
val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala
index 95b55797b294c..083c522287cab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.logging.log4j.Level
+import org.slf4j.event.{Level => Slf4jLevel}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -103,8 +104,8 @@ class OptimizerLoggingSuite extends PlanTest {
val error = intercept[IllegalArgumentException] {
withSQLConf(SQLConf.PLAN_CHANGE_LOG_LEVEL.key -> level) {}
}
- assert(error.getMessage.contains(
- "Invalid value for 'spark.sql.planChangeLog.level'."))
+ assert(error.getMessage == s"${SQLConf.PLAN_CHANGE_LOG_LEVEL.key} should be one of " +
+ s"${classOf[Slf4jLevel].getEnumConstants.mkString(", ")}, but was $level")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffsetSuite.scala
similarity index 56%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffsetSuite.scala
index 9af73158ee732..9a57630ebc13b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectionThroughLimitAndOffsetSuite.scala
@@ -20,16 +20,18 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Add
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-class PushProjectionThroughLimitSuite extends PlanTest {
+class PushProjectionThroughLimitAndOffsetSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Optimizer Batch",
FixedPoint(100),
- PushProjectionThroughLimit,
- EliminateLimits) :: Nil
+ PushProjectionThroughLimitAndOffset,
+ EliminateLimits,
+ LimitPushDown) :: Nil
}
test("SPARK-40501: push projection through limit") {
@@ -87,4 +89,74 @@ class PushProjectionThroughLimitSuite extends PlanTest {
.limit(10).analyze
comparePlans(optimized4, expected4)
}
+
+ test("push projection through offset") {
+ val testRelation = LocalRelation.fromExternalRows(
+ Seq("a".attr.int, "b".attr.int, "c".attr.int),
+ 1.to(30).map(_ => Row(1, 2, 3)))
+
+ val query1 = testRelation
+ .offset(5)
+ .select($"a", $"b", $"c")
+ .analyze
+ val optimized1 = Optimize.execute(query1)
+ val expected1 = testRelation
+ .select($"a", $"b", $"c")
+ .offset(5).analyze
+ comparePlans(optimized1, expected1)
+
+ val query2 = testRelation
+ .limit(15).offset(5)
+ .select($"a", $"b", $"c")
+ .analyze
+ val optimized2 = Optimize.execute(query2)
+ val expected2 = testRelation
+ .select($"a", $"b", $"c")
+ .limit(15).offset(5).analyze
+ comparePlans(optimized2, expected2)
+
+ val query3 = testRelation
+ .offset(5).limit(15)
+ .select($"a", $"b", $"c")
+ .analyze
+ val optimized3 = Optimize.execute(query3)
+ val expected3 = testRelation
+ .select($"a", $"b", $"c")
+ .localLimit(Add(15, 5)).offset(5).globalLimit(15)
+ .analyze
+ comparePlans(optimized3, expected3)
+
+ val query4 = testRelation
+ .offset(5).limit(15)
+ .select($"a", $"b", $"c")
+ .limit(10).analyze
+ val optimized4 = Optimize.execute(query4)
+ val expected4 = testRelation
+ .select($"a", $"b", $"c")
+ .localLimit(Add(10, 5)).offset(5).globalLimit(10)
+ .analyze
+ comparePlans(optimized4, expected4)
+
+ val query5 = testRelation
+ .localLimit(10)
+ .select($"a", $"b", $"c")
+ .offset(5).limit(10).analyze
+ val optimized5 = Optimize.execute(query5)
+ val expected5 = testRelation
+ .select($"a", $"b", $"c")
+ .localLimit(10).offset(5).globalLimit(10)
+ .analyze
+ comparePlans(optimized5, expected5)
+
+ val query6 = testRelation
+ .localLimit(20)
+ .select($"a", $"b", $"c")
+ .offset(5).limit(10).analyze
+ val optimized6 = Optimize.execute(query6)
+ val expected6 = testRelation
+ .select($"a", $"b", $"c")
+ .localLimit(15).offset(5).globalLimit(10)
+ .analyze
+ comparePlans(optimized6, expected6)
+ }
}
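The expected plans in the new offset tests follow a simple arithmetic rule: when a LIMIT sits above an OFFSET, the global limit stays at limit while the pushed-down local limit becomes limit + offset, capped by any pre-existing local limit. A rough model of that rule, shown only as a simplification and not the actual LimitPushDown code:

def pushedLocalLimit(limit: Int, offset: Int, existingLocalLimit: Option[Int]): Int = {
  val combined = limit + offset                      // rows the local side must keep
  existingLocalLimit.fold(combined)(math.min(_, combined))
}

assert(pushedLocalLimit(15, 5, None) == 20)          // query3: localLimit(Add(15, 5))
assert(pushedLocalLimit(10, 5, Some(10)) == 10)      // query5: localLimit(10)
assert(pushedLocalLimit(10, 5, Some(20)) == 15)      // query6: localLimit(15)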
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala
index 8a0a0466ca741..552a638f6e614 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala
@@ -178,4 +178,64 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest {
comparePlans(optimized, expectedWhenNotExcluded)
}
}
+
+ test("SPARK-46640: exclude outer references accounts for children of plan expression") {
+ val a = $"a".int
+ val a_alias = Alias(a, "a")()
+ val a_alias_attr = a_alias.toAttribute
+
+ // The original input query
+ // Project [CASE WHEN exists#2 [a#1 && (a#1 = a#0)] THEN 1 ELSE 2 END AS result#3]
+ // : +- LocalRelation , [a#0]
+ // +- Project [a#0 AS a#1]
+ // +- LocalRelation , [a#0]
+ // The subquery expression (`exists#2`) is wrapped in a CaseWhen and an Alias.
+ // Without the fix on excluding outer references, the rewritten plan would have been:
+ // Project [CASE WHEN exists#2 [a#0 && (a#0 = a#0)] THEN 1 ELSE 2 END AS result#3]
+ // : +- LocalRelation , [a#0]
+ // +- LocalRelation , [a#0]
+ // This plan would then fail later with the error -- conflicting a#0 in join condition.
+
+ val query = Project(Seq(
+ Alias(
+ CaseWhen(Seq((
+ Exists(
+ LocalRelation(a),
+ outerAttrs = Seq(a_alias_attr),
+ joinCond = Seq(EqualTo(a_alias_attr, a))
+ ), Literal(1))),
+ Some(Literal(2))),
+ "result"
+ )()),
+ Project(Seq(a_alias), LocalRelation(a))
+ )
+
+ // The alias would not be removed if excluding subquery references is enabled.
+ val expectedWhenExcluded = query
+
+ // The alias would be removed and we would have conflicting expression ID(s) in the join cond
+ val expectedWhenNotEnabled = Project(Seq(
+ Alias(
+ CaseWhen(Seq((
+ Exists(
+ LocalRelation(a),
+ outerAttrs = Seq(a),
+ joinCond = Seq(EqualTo(a, a))
+ ), Literal(1))),
+ Some(Literal(2))),
+ "result"
+ )()),
+ LocalRelation(a)
+ )
+
+ withSQLConf(SQLConf.EXCLUDE_SUBQUERY_EXP_REFS_FROM_REMOVE_REDUNDANT_ALIASES.key -> "true") {
+ val optimized = Optimize.execute(query)
+ comparePlans(optimized, expectedWhenExcluded)
+ }
+
+ withSQLConf(SQLConf.EXCLUDE_SUBQUERY_EXP_REFS_FROM_REMOVE_REDUNDANT_ALIASES.key -> "false") {
+ val optimized = Optimize.execute(query)
+ comparePlans(optimized, expectedWhenNotEnabled)
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index c8d2de9c6b8de..1589bcb8a3d7e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -2705,7 +2705,7 @@ class DDLParserSuite extends AnalysisTest {
val createTableResult =
CreateTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithDefaultValue,
Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"),
- OptionList(Seq.empty), None, None, None, None, false), false)
+ OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false)
// Parse the CREATE TABLE statement twice, swapping the order of the NOT NULL and DEFAULT
// options, to make sure that the parser accepts any ordering of these options.
comparePlans(parsePlan(
@@ -2718,7 +2718,7 @@ class DDLParserSuite extends AnalysisTest {
"b STRING NOT NULL DEFAULT 'abc') USING parquet"),
ReplaceTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithDefaultValue,
Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"),
- OptionList(Seq.empty), None, None, None, None, false), false))
+ OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false))
// These ALTER TABLE statements should parse successfully.
comparePlans(
parsePlan("ALTER TABLE t1 ADD COLUMN x int NOT NULL DEFAULT 42"),
@@ -2881,12 +2881,12 @@ class DDLParserSuite extends AnalysisTest {
"CREATE TABLE my_tab(a INT, b INT NOT NULL GENERATED ALWAYS AS (a+1)) USING parquet"),
CreateTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithGenerationExpr,
Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"),
- OptionList(Seq.empty), None, None, None, None, false), false))
+ OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false))
comparePlans(parsePlan(
"REPLACE TABLE my_tab(a INT, b INT NOT NULL GENERATED ALWAYS AS (a+1)) USING parquet"),
ReplaceTable(UnresolvedIdentifier(Seq("my_tab")), columnsWithGenerationExpr,
Seq.empty[Transform], UnresolvedTableSpec(Map.empty[String, String], Some("parquet"),
- OptionList(Seq.empty), None, None, None, None, false), false))
+ OptionList(Seq.empty), None, None, None, None, false, Seq.empty), false))
// Two generation expressions
checkError(
exception = parseException("CREATE TABLE my_tab(a INT, " +
@@ -2957,7 +2957,8 @@ class DDLParserSuite extends AnalysisTest {
None,
None,
None,
- false
+ false,
+ Seq.empty
),
false
)
@@ -2980,7 +2981,8 @@ class DDLParserSuite extends AnalysisTest {
None,
None,
None,
- false
+ false,
+ Seq.empty
),
false
)
@@ -3273,7 +3275,7 @@ class DDLParserSuite extends AnalysisTest {
Seq(ColumnDefinition("c", StringType)),
Seq.empty[Transform],
UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), OptionList(Seq.empty),
- None, None, Some(collation), None, false), false))
+ None, None, Some(collation), None, false, Seq.empty), false))
}
}
@@ -3285,7 +3287,7 @@ class DDLParserSuite extends AnalysisTest {
Seq(ColumnDefinition("c", StringType)),
Seq.empty[Transform],
UnresolvedTableSpec(Map.empty[String, String], Some("parquet"), OptionList(Seq.empty),
- None, None, Some(collation), None, false), false))
+ None, None, Some(collation), None, false, Seq.empty), false))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index c416d21cfd4b0..8b61328a00999 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.parser
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.TimestampTypes
@@ -57,6 +57,9 @@ class DataTypeParserSuite extends SparkFunSuite with SQLHelper {
checkDataType("Dec(10, 5)", DecimalType(10, 5))
checkDataType("deC", DecimalType.USER_DEFAULT)
checkDataType("DATE", DateType)
+ checkDataType("TimE", TimeType())
+ checkDataType("time(0)", TimeType(0))
+ checkDataType("TIME(6)", TimeType(6))
checkDataType("timestamp", TimestampType)
checkDataType("timestamp_ntz", TimestampNTZType)
checkDataType("timestamp_ltz", TimestampType)
@@ -172,4 +175,19 @@ class DataTypeParserSuite extends SparkFunSuite with SQLHelper {
// DataType parser accepts comments.
checkDataType("Struct",
(new StructType).add("x", IntegerType).add("y", StringType, true, "test"))
+
+ test("unsupported precision of the time data type") {
+ checkError(
+ exception = intercept[SparkException] {
+ CatalystSqlParser.parseDataType("time(9)")
+ },
+ condition = "UNSUPPORTED_TIME_PRECISION",
+ parameters = Map("precision" -> "9"))
+ checkError(
+ exception = intercept[ParseException] {
+ CatalystSqlParser.parseDataType("time(-1)")
+ },
+ condition = "PARSE_SYNTAX_ERROR",
+ parameters = Map("error" -> "'('", "hint" -> ""))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index f57b2230740ba..17b03946251a2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.parser
import java.sql.{Date, Timestamp}
-import java.time.{Duration, LocalDateTime, Period}
+import java.time.{Duration, LocalDateTime, LocalTime, Period}
import java.util.concurrent.TimeUnit
import scala.language.implicitConversions
@@ -1194,16 +1194,18 @@ class ExpressionParserSuite extends AnalysisTest {
}
}
- test("current date/timestamp braceless expressions") {
+ test("current date/timestamp/time braceless expressions") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true",
SQLConf.ENFORCE_RESERVED_KEYWORDS.key -> "true") {
assertEqual("current_date", CurrentDate())
assertEqual("current_timestamp", CurrentTimestamp())
+ assertEqual("current_time", CurrentTime())
}
def testNonAnsiBehavior(): Unit = {
assertEqual("current_date", UnresolvedAttribute.quoted("current_date"))
assertEqual("current_timestamp", UnresolvedAttribute.quoted("current_timestamp"))
+ assertEqual("current_time", UnresolvedAttribute.quoted("current_time"))
}
withSQLConf(
SQLConf.ANSI_ENABLED.key -> "false",
@@ -1238,4 +1240,19 @@ class ExpressionParserSuite extends AnalysisTest {
stop = 9 + quantifier.length))
}
}
+
+ test("time literals") {
+ assertEqual("tIme '12:13:14'", Literal(LocalTime.parse("12:13:14")))
+ assertEqual("TIME'23:59:59.999999'", Literal(LocalTime.parse("23:59:59.999999")))
+
+ checkError(
+ exception = parseException("time '12-13.14'"),
+ condition = "INVALID_TYPED_LITERAL",
+ sqlState = "42604",
+ parameters = Map("valueType" -> "\"TIME\"", "value" -> "'12-13.14'"),
+ context = ExpectedContext(
+ fragment = "time '12-13.14'",
+ start = 0,
+ stop = 14))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala
index 41c87dd804be1..f281c42bbe715 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit
import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.getZoneId
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, localTimeToMicros}
/**
* Helper functions for testing date and time functionality.
@@ -112,4 +112,15 @@ object DateTimeTestUtils {
result = Math.addExact(result, Math.multiplyExact(seconds, MICROS_PER_SECOND))
result
}
+
+ // Returns microseconds since midnight
+ def localTime(
+ hour: Byte = 0,
+ minute: Byte = 0,
+ sec: Byte = 0,
+ micros: Int = 0): Long = {
+ val nanos = TimeUnit.MICROSECONDS.toNanos(micros).toInt
+ val localTime = LocalTime.of(hour, minute, sec, nanos)
+ localTimeToMicros(localTime)
+ }
}
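For reference, the new localTime() helper encodes a wall-clock time as microseconds since midnight, so the largest representable value is one microsecond short of a full day. A quick worked check, illustrative only:

val lastMicroOfDay = (23L * 3600 + 59 * 60 + 59) * 1000000 + 999999
assert(lastMicroOfDay == 24L * 60 * 60 * 1000000 - 1)   // 86399999999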
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index c253272e2bbb7..24258a2268ba6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -19,19 +19,20 @@ package org.apache.spark.sql.catalyst.util
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
-import java.time.{Instant, LocalDate, LocalDateTime, LocalTime, ZoneId}
+import java.time.{DateTimeException, Instant, LocalDate, LocalDateTime, LocalTime, ZoneId}
import java.util.Locale
import java.util.concurrent.TimeUnit
import org.scalatest.matchers.must.Matchers
import org.scalatest.matchers.should.Matchers._
-import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException}
+import org.apache.spark.{SparkDateTimeException, SparkFunSuite, SparkIllegalArgumentException}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.RebaseDateTime.rebaseJulianToGregorianMicros
+import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
@@ -763,10 +764,10 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
}
test("SPARK-35664: microseconds to LocalDateTime") {
- assert(microsToLocalDateTime(0) == LocalDateTime.parse("1970-01-01T00:00:00"))
- assert(microsToLocalDateTime(100) == LocalDateTime.parse("1970-01-01T00:00:00.0001"))
- assert(microsToLocalDateTime(100000000) == LocalDateTime.parse("1970-01-01T00:01:40"))
- assert(microsToLocalDateTime(100000000000L) == LocalDateTime.parse("1970-01-02T03:46:40"))
+ assert(microsToLocalDateTime(0) == LocalDateTime.parse("1970-01-01T00:00:00"))
+ assert(microsToLocalDateTime(100) == LocalDateTime.parse("1970-01-01T00:00:00.0001"))
+ assert(microsToLocalDateTime(100000000) == LocalDateTime.parse("1970-01-01T00:01:40"))
+ assert(microsToLocalDateTime(100000000000L) == LocalDateTime.parse("1970-01-02T03:46:40"))
assert(microsToLocalDateTime(253402300799999999L) ==
LocalDateTime.parse("9999-12-31T23:59:59.999999"))
assert(microsToLocalDateTime(Long.MinValue) ==
@@ -1105,4 +1106,112 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
"parameter" -> "`unit`",
"invalidValue" -> "'SECS'"))
}
+
+ test("localTimeToMicros and microsToLocalTime") {
+ assert(microsToLocalTime(0) === LocalTime.of(0, 0))
+ assert(localTimeToMicros(LocalTime.of(0, 0)) === 0)
+
+ assert(localTimeToMicros(microsToLocalTime(123456789)) === 123456789)
+
+ assert(localTimeToMicros(LocalTime.parse("23:59:59.999999")) === (24L * 60 * 60 * 1000000 - 1))
+ assert(microsToLocalTime(24L * 60 * 60 * 1000000 - 1) === LocalTime.of(23, 59, 59, 999999000))
+
+ Seq(-1, 24L * 60 * 60 * 1000000).foreach { invalidMicros =>
+ val msg = intercept[DateTimeException] {
+ microsToLocalTime(invalidMicros)
+ }.getMessage
+ assert(msg.contains("Invalid value"))
+ }
+ val msg = intercept[ArithmeticException] {
+ microsToLocalTime(Long.MaxValue)
+ }.getMessage
+ assert(msg == "long overflow")
+ }
+
+ test("stringToTime") {
+ def checkStringToTime(str: String, expected: Option[Long]): Unit = {
+ assert(stringToTime(UTF8String.fromString(str)) === expected)
+ }
+
+ checkStringToTime("00:00", Some(localTime()))
+ checkStringToTime("00:00:00", Some(localTime()))
+ checkStringToTime("00:00:00.1", Some(localTime(micros = 100000)))
+ checkStringToTime("00:00:59.01", Some(localTime(sec = 59, micros = 10000)))
+ checkStringToTime("00:59:00.001", Some(localTime(minute = 59, micros = 1000)))
+ checkStringToTime("23:00:00.0001", Some(localTime(hour = 23, micros = 100)))
+ checkStringToTime("23:59:00.00001", Some(localTime(hour = 23, minute = 59, micros = 10)))
+ checkStringToTime("23:59:59.000001",
+ Some(localTime(hour = 23, minute = 59, sec = 59, micros = 1)))
+ checkStringToTime("23:59:59.999999",
+ Some(localTime(hour = 23, minute = 59, sec = 59, micros = 999999)))
+
+ checkStringToTime("1:2:3.0", Some(localTime(hour = 1, minute = 2, sec = 3)))
+ checkStringToTime("T1:02:3.04", Some(localTime(hour = 1, minute = 2, sec = 3, micros = 40000)))
+
+ // Negative tests
+ Seq("2025-03-09 00:00:00", "00", "00:01:02 UTC").foreach { invalidTime =>
+ checkStringToTime(invalidTime, None)
+ }
+ }
+
+ test("stringToTimeAnsi") {
+ Seq("2025-03-09T00:00:00", "012", "00:01:02Z").foreach { invalidTime =>
+ checkError(
+ exception = intercept[SparkDateTimeException] {
+ stringToTimeAnsi(UTF8String.fromString(invalidTime))
+ },
+ condition = "CAST_INVALID_INPUT",
+ parameters = Map(
+ "expression" -> s"'$invalidTime'",
+ "sourceType" -> "\"STRING\"",
+ "targetType" -> "\"TIME(6)\"",
+ "ansiConfig" -> "\"spark.sql.ansi.enabled\""))
+ }
+ }
+
+ test("timeToMicros") {
+ val hour = 13
+ val min = 2
+ val sec = 23
+ val micros = 1234
+ val secAndMicros = Decimal(sec + (micros / MICROS_PER_SECOND.toFloat), 16, 6)
+
+ // Valid case
+ val microSecsTime = timeToMicros(hour, min, secAndMicros)
+ assert(microSecsTime === localTime(hour.toByte, min.toByte, sec.toByte, micros))
+
+ // Invalid hour
+ checkError(
+ exception = intercept[SparkDateTimeException] {
+ timeToMicros(-1, min, secAndMicros)
+ },
+ condition = "DATETIME_FIELD_OUT_OF_BOUNDS.WITHOUT_SUGGESTION",
+ parameters = Map("rangeMessage" -> "Invalid value for HourOfDay (valid values 0 - 23): -1"))
+
+ // Invalid minute
+ checkError(
+ exception = intercept[SparkDateTimeException] {
+ timeToMicros(hour, -1, secAndMicros)
+ },
+ condition = "DATETIME_FIELD_OUT_OF_BOUNDS.WITHOUT_SUGGESTION",
+ parameters = Map("rangeMessage" ->
+ "Invalid value for MinuteOfHour (valid values 0 - 59): -1"))
+
+ // Invalid second cases
+ Seq(
+ 60.0,
+ 9999999999.999999,
+ -999999999.999999,
+ // The full-seconds value overflows to a valid seconds integer when converted from long to int
+ 4294967297.999999
+ ).foreach { invalidSecond =>
+ checkError(
+ exception = intercept[SparkDateTimeException] {
+ timeToMicros(hour, min, Decimal(invalidSecond, 16, 6))
+ },
+ condition = "DATETIME_FIELD_OUT_OF_BOUNDS.WITHOUT_SUGGESTION",
+ parameters = Map("rangeMessage" ->
+ s"Invalid value for SecondOfMinute (valid values 0 - 59): ${invalidSecond.toLong}"))
+ }
+ }
}
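The 4294967297.999999 entry in the invalid-seconds list exists because that full-seconds value narrows to a seemingly valid int, so the bounds check has to run on the original long. A small sketch of the narrowing the comment refers to; this mirrors standard JVM long-to-int truncation rather than the exact DateTimeUtils code:

val fullSeconds = 4294967297L      // 2^32 + 1
assert(fullSeconds.toInt == 1)     // would look like a valid second after narrowing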
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimeFormatterSuite.scala
new file mode 100644
index 0000000000000..d99ea2bd1042b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimeFormatterSuite.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import java.time.DateTimeException
+
+import scala.util.Random
+
+import org.apache.spark.{SPARK_DOC_ROOT, SparkDateTimeException, SparkFunSuite, SparkRuntimeException}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToLocalTime
+
+class TimeFormatterSuite extends SparkFunSuite with SQLHelper {
+
+ test("time parsing") {
+ Seq(
+ ("12", "HH") -> 12 * 3600 * 1000000L,
+ ("01:02", "HH:mm") -> (1 * 3600 + 2 * 60) * 1000000L,
+ ("10:20", "HH:mm") -> (10 * 3600 + 20 * 60) * 1000000L,
+ ("00:00:00", "HH:mm:ss") -> 0L,
+ ("01:02:03", "HH:mm:ss") -> (1 * 3600 + 2 * 60 + 3) * 1000000L,
+ ("23:59:59", "HH:mm:ss") -> (23 * 3600 + 59 * 60 + 59) * 1000000L,
+ ("00:00:00.000000", "HH:mm:ss.SSSSSS") -> 0L,
+ ("12:34:56.789012", "HH:mm:ss.SSSSSS") -> ((12 * 3600 + 34 * 60 + 56) * 1000000L + 789012),
+ ("23:59:59.000000", "HH:mm:ss.SSSSSS") -> (23 * 3600 + 59 * 60 + 59) * 1000000L,
+ ("23:59:59.999999", "HH:mm:ss.SSSSSS") -> ((23 * 3600 + 59 * 60 + 59) * 1000000L + 999999)
+ ).foreach { case ((inputStr, pattern), expectedMicros) =>
+ val formatter = TimeFormatter(format = pattern, isParsing = true)
+ assert(formatter.parse(inputStr) === expectedMicros)
+ }
+ }
+
+ test("time strings do not match to the pattern") {
+ def assertError(str: String, expectedMsg: String): Unit = {
+ val e = intercept[DateTimeException] {
+ TimeFormatter(format = "HH:mm:ss", isParsing = true).parse(str)
+ }
+ assert(e.getMessage.contains(expectedMsg))
+ }
+
+ assertError("11.12.13", "Text '11.12.13' could not be parsed")
+ assertError("25:00:00", "Text '25:00:00' could not be parsed: Invalid value")
+ }
+
+ test("time formatting") {
+ Seq(
+ (12 * 3600 * 1000000L, "HH") -> "12",
+ ((1 * 3600 + 2 * 60) * 1000000L, "HH:mm") -> "01:02",
+ ((10 * 3600 + 20 * 60) * 1000000L, "HH:mm") -> "10:20",
+ (0L, "HH:mm:ss") -> "00:00:00",
+ ((1 * 3600 + 2 * 60 + 3) * 1000000L, "HH:mm:ss") -> "01:02:03",
+ ((23 * 3600 + 59 * 60 + 59) * 1000000L, "HH:mm:ss") -> "23:59:59",
+ (0L, "HH:mm:ss.SSSSSS") -> "00:00:00.000000",
+ ((12 * 3600 + 34 * 60 + 56) * 1000000L + 789012, "HH:mm:ss.SSSSSS") -> "12:34:56.789012",
+ ((23 * 3600 + 59 * 60 + 59) * 1000000L, "HH:mm:ss.SSSSSS") -> "23:59:59.000000",
+ ((23 * 3600 + 59 * 60 + 59) * 1000000L + 999999, "HH:mm:ss.SSSSSS") -> "23:59:59.999999"
+ ).foreach { case ((micros, pattern), expectedStr) =>
+ val formatter = TimeFormatter(format = pattern)
+ assert(formatter.format(micros) === expectedStr)
+ }
+ }
+
+ test("micros are out of supported range") {
+ def assertError(micros: Long, expectedMsg: String): Unit = {
+ val e = intercept[DateTimeException](TimeFormatter(isParsing = false).format(micros))
+ assert(e.getMessage.contains(expectedMsg))
+ }
+
+ assertError(-1, "Invalid value for NanoOfDay (valid values 0 - 86399999999999): -1000")
+ assertError(25L * 3600 * 1000 * 1000,
+ "Invalid value for NanoOfDay (valid values 0 - 86399999999999): 90000000000000")
+ }
+
+ test("invalid pattern") {
+ Seq("hHH:mmm:s", "kkk", "GGGGGG").foreach { invalidPattern =>
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ TimeFormatter(invalidPattern)
+ },
+ condition = "INVALID_DATETIME_PATTERN.WITH_SUGGESTION",
+ parameters = Map(
+ "pattern" -> s"'$invalidPattern'",
+ "docroot" -> SPARK_DOC_ROOT))
+ }
+ }
+
+ test("round trip with the default pattern: format -> parse") {
+ val data = Seq.tabulate(10) { _ => Random.between(0, 24 * 60 * 60 * 1000000L) }
+ val pattern = "HH:mm:ss.SSSSSS"
+ val (formatter, parser) =
+ (TimeFormatter(pattern, isParsing = false), TimeFormatter(pattern, isParsing = true))
+ data.foreach { micros =>
+ val str = formatter.format(micros)
+ assert(parser.parse(str) === micros, s"micros = $micros")
+ assert(formatter.format(microsToLocalTime(micros)) === str)
+ }
+ }
+
+ test("format fraction of second") {
+ val formatter = new FractionTimeFormatter()
+ Seq(
+ 0 -> "00:00:00",
+ 1 -> "00:00:00.000001",
+ 1000 -> "00:00:00.001",
+ 900000 -> "00:00:00.9",
+ 1000000 -> "00:00:01").foreach { case (micros, tsStr) =>
+ assert(formatter.format(micros) === tsStr)
+ assert(formatter.format(microsToLocalTime(micros)) === tsStr)
+ }
+ }
+
+ test("missing am/pm field") {
+ Seq("HH", "hh", "KK", "kk").foreach { hour =>
+ val formatter = TimeFormatter(format = s"$hour:mm:ss", isParsing = true)
+ val micros = formatter.parse("11:30:01")
+ assert(micros === localTime(11, 30, 1))
+ }
+ }
+
+ test("missing hour field") {
+ val f1 = TimeFormatter(format = "mm:ss a", isParsing = true)
+ val t1 = f1.parse("30:01 PM")
+ assert(t1 === localTime(12, 30, 1))
+ val t2 = f1.parse("30:01 AM")
+ assert(t2 === localTime(0, 30, 1))
+ val f2 = TimeFormatter(format = "mm:ss", isParsing = true)
+ val t3 = f2.parse("30:01")
+ assert(t3 === localTime(0, 30, 1))
+ val f3 = TimeFormatter(format = "a", isParsing = true)
+ val t4 = f3.parse("PM")
+ assert(t4 === localTime(12))
+ val t5 = f3.parse("AM")
+ assert(t5 === localTime())
+ }
+
+ test("default parsing w/o pattern") {
+ val formatter = new DefaultTimeFormatter(
+ locale = DateFormatter.defaultLocale,
+ isParsing = true)
+ Seq(
+ "00:00:00" -> localTime(),
+ "00:00:00.000001" -> localTime(micros = 1),
+ "01:02:03" -> localTime(1, 2, 3),
+ "1:2:3.999999" -> localTime(1, 2, 3, 999999),
+ "23:59:59.1" -> localTime(23, 59, 59, 100000)
+ ).foreach { case (inputStr, micros) =>
+ assert(formatter.parse(inputStr) === micros)
+ }
+
+ checkError(
+ exception = intercept[SparkDateTimeException] {
+ formatter.parse("x123")
+ },
+ condition = "CAST_INVALID_INPUT",
+ parameters = Map(
+ "expression" -> "'x123'",
+ "sourceType" -> "\"STRING\"",
+ "targetType" -> "\"TIME(6)\"",
+ "ansiConfig" -> "\"spark.sql.ansi.enabled\""))
+ }
+}
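The formatting cases above map microseconds-of-day to strings such as "12:34:56.789012". The same mapping can be reproduced with plain java.time, shown here only as an illustration; Spark's TimeFormatter builds its own parser and formatter:

import java.time.LocalTime
import java.time.format.DateTimeFormatter

val micros = (12L * 3600 + 34 * 60 + 56) * 1000000 + 789012
val formatted = LocalTime.ofNanoOfDay(micros * 1000L)
  .format(DateTimeFormatter.ofPattern("HH:mm:ss.SSSSSS"))
assert(formatted == "12:34:56.789012")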
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
index 51ea945984b50..19df744145695 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
@@ -23,12 +23,14 @@ import java.util.Collections
import scala.jdk.CollectionConverters._
import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException, SparkUnsupportedOperationException}
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.quoteIdentifier
+import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
-import org.apache.spark.sql.connector.expressions.{Expressions, LogicalExpressions, Transform}
+import org.apache.spark.sql.connector.expressions.{Expressions, FieldReference, LogicalExpressions, Transform}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -57,6 +59,14 @@ class CatalogSuite extends SparkFunSuite {
private val testIdentNewQuoted = testIdentNew.asMultipartIdentifier
.map(part => quoteIdentifier(part)).mkString(".")
+ private val constraints: Array[Constraint] = Array(
+ Constraint.primaryKey("pk", Array(FieldReference.column("id"))).build(),
+ Constraint.check("chk").predicateSql("id > 0").build(),
+ Constraint.unique("uk", Array(FieldReference.column("data"))).build(),
+ Constraint.foreignKey("fk", Array(FieldReference.column("data")), testIdentNew,
+ Array(FieldReference.column("id"))).build()
+ )
+
test("Catalogs can load the catalog") {
val catalog = newCatalog()
@@ -75,7 +85,11 @@ class CatalogSuite extends SparkFunSuite {
intercept[NoSuchNamespaceException](catalog.listTables(Array("ns")))
- catalog.createTable(ident1, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps).build()
+ catalog.createTable(ident1, tableInfo)
assert(catalog.listTables(Array("ns")).toSet == Set(ident1))
intercept[NoSuchNamespaceException](catalog.listTables(Array("ns2")))
@@ -101,7 +115,12 @@ class CatalogSuite extends SparkFunSuite {
assert(!catalog.tableExists(testIdent))
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name)
assert(parsed == Seq("test", "`", ".", "test_table"))
@@ -120,11 +139,12 @@ class CatalogSuite extends SparkFunSuite {
val columns = Array(
Column.create("col0", IntegerType),
Column.create("part0", IntegerType))
- val table = partCatalog.createTable(
- testIdent,
- columns,
- Array[Transform](Expressions.identity("part0")),
- util.Collections.emptyMap[String, String])
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(Array[Transform](Expressions.identity("part0")))
+ .withProperties(util.Collections.emptyMap[String, String])
+ .build()
+ val table = partCatalog.createTable(testIdent, tableInfo)
val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name)
assert(parsed == Seq("test", "`", ".", "test_table"))
@@ -142,7 +162,12 @@ class CatalogSuite extends SparkFunSuite {
assert(!catalog.tableExists(testIdent))
- val table = catalog.createTable(testIdent, columns, emptyTrans, properties)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(properties)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name)
assert(parsed == Seq("test", "`", ".", "test_table"))
@@ -152,15 +177,43 @@ class CatalogSuite extends SparkFunSuite {
assert(catalog.tableExists(testIdent))
}
+ test("createTable: with constraints") {
+ val catalog = newCatalog()
+
+ val columns = Array(
+ Column.create("id", IntegerType, false),
+ Column.create("data", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .withConstraints(constraints)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
+
+ val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name)
+ assert(parsed == Seq("test", "`", ".", "test_table"))
+ assert(table.columns === columns)
+ assert(table.constraints === constraints)
+ assert(table.properties.asScala == Map())
+
+ assert(catalog.tableExists(testIdent))
+ }
+
test("createTable: table already exists") {
val catalog = newCatalog()
assert(!catalog.tableExists(testIdent))
- catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ catalog.createTable(testIdent, tableInfo)
val exc = intercept[TableAlreadyExistsException] {
- catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ catalog.createTable(testIdent, tableInfo)
}
checkErrorTableAlreadyExists(exc, testIdentQuoted)
@@ -173,7 +226,12 @@ class CatalogSuite extends SparkFunSuite {
assert(!catalog.tableExists(testIdent))
- catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ catalog.createTable(testIdent, tableInfo)
assert(catalog.tableExists(testIdent))
@@ -185,7 +243,12 @@ class CatalogSuite extends SparkFunSuite {
test("loadTable") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
val loaded = catalog.loadTable(testIdent)
assert(table.name == loaded.name)
@@ -206,7 +269,12 @@ class CatalogSuite extends SparkFunSuite {
test("invalidateTable") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
catalog.invalidateTable(testIdent)
val loaded = catalog.loadTable(testIdent)
@@ -227,7 +295,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: add property") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.properties.asScala == Map())
@@ -246,7 +319,12 @@ class CatalogSuite extends SparkFunSuite {
val properties = new util.HashMap[String, String]()
properties.put("prop-1", "1")
- val table = catalog.createTable(testIdent, columns, emptyTrans, properties)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(properties)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.properties.asScala == Map("prop-1" -> "1"))
@@ -265,7 +343,12 @@ class CatalogSuite extends SparkFunSuite {
val properties = new util.HashMap[String, String]()
properties.put("prop-1", "1")
- val table = catalog.createTable(testIdent, columns, emptyTrans, properties)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(properties)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.properties.asScala == Map("prop-1" -> "1"))
@@ -281,7 +364,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: remove missing property") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.properties.asScala == Map())
@@ -297,7 +385,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: add top-level column") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -309,7 +402,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: add required column") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -322,7 +420,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: add column with comment") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -339,7 +442,12 @@ class CatalogSuite extends SparkFunSuite {
val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
val tableColumns = columns :+ Column.create("point", pointStruct)
- val table = catalog.createTable(testIdent, tableColumns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === tableColumns)
@@ -354,7 +462,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: add column to primitive field fails") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -372,7 +485,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: add field to missing column fails") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -388,7 +506,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: update column data type") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -404,7 +527,12 @@ class CatalogSuite extends SparkFunSuite {
val originalColumns = Array(
Column.create("id", IntegerType, false),
Column.create("data", StringType))
- val table = catalog.createTable(testIdent, originalColumns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(originalColumns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === originalColumns)
@@ -418,7 +546,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: update missing column fails") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -434,7 +567,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: add comment") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -450,7 +588,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: replace comment") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -469,7 +612,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: add comment to missing column fails") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -485,7 +633,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: rename top-level column") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -502,7 +655,12 @@ class CatalogSuite extends SparkFunSuite {
val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
val tableColumns = columns :+ Column.create("point", pointStruct)
- val table = catalog.createTable(testIdent, tableColumns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === tableColumns)
@@ -521,7 +679,12 @@ class CatalogSuite extends SparkFunSuite {
val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
val tableColumns = columns :+ Column.create("point", pointStruct)
- val table = catalog.createTable(testIdent, tableColumns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === tableColumns)
@@ -537,7 +700,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: rename missing column fails") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -556,7 +724,12 @@ class CatalogSuite extends SparkFunSuite {
val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
val tableColumns = columns :+ Column.create("point", pointStruct)
- val table = catalog.createTable(testIdent, tableColumns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === tableColumns)
@@ -573,7 +746,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: delete top-level column") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -590,7 +768,12 @@ class CatalogSuite extends SparkFunSuite {
val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
val tableColumns = columns :+ Column.create("point", pointStruct)
- val table = catalog.createTable(testIdent, tableColumns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === tableColumns)
@@ -606,7 +789,12 @@ class CatalogSuite extends SparkFunSuite {
test("alterTable: delete missing column fails") {
val catalog = newCatalog()
- val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === columns)
@@ -628,7 +816,12 @@ class CatalogSuite extends SparkFunSuite {
val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
val tableColumns = columns :+ Column.create("point", pointStruct)
- val table = catalog.createTable(testIdent, tableColumns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
assert(table.columns === tableColumns)
@@ -654,12 +847,120 @@ class CatalogSuite extends SparkFunSuite {
checkErrorTableNotFound(exc, testIdentQuoted)
}
+ test("alterTable: add constraint") {
+ val catalog = newCatalog()
+
+ val tableColumns = Array(
+ Column.create("id", IntegerType, false),
+ Column.create("data", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
+
+ assert(table.constraints.isEmpty)
+
+ for ((constraint, index) <- constraints.zipWithIndex) {
+ val updated = catalog.alterTable(testIdent, TableChange.addConstraint(constraint, null))
+ assert(updated.constraints.length === index + 1)
+ assert(updated.constraints.apply(index) === constraint)
+ }
+ }
+
+ test("alterTable: add existing constraint should fail") {
+ val catalog = newCatalog()
+
+ val tableColumns = Array(
+ Column.create("id", IntegerType, false),
+ Column.create("data", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .withConstraints(constraints)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
+
+ assert(table.constraints.length === constraints.length)
+
+ for (constraint <- constraints) {
+ checkError(
+ exception = intercept[AnalysisException] {
+ catalog.alterTable(testIdent, TableChange.addConstraint(constraint, null))
+ },
+ condition = "CONSTRAINT_ALREADY_EXISTS",
+ parameters = Map("constraintName" -> constraint.name, "oldConstraint" -> constraint.toDDL))
+ }
+ }
+
+ test("alterTable: drop constraint") {
+ val catalog = newCatalog()
+
+ val tableColumns = Array(
+ Column.create("id", IntegerType, false),
+ Column.create("data", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .withConstraints(constraints)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
+
+ assert(table.constraints.length === constraints.length)
+
+ for ((constraint, index) <- constraints.zipWithIndex) {
+ val updated =
+ catalog.alterTable(testIdent, TableChange.dropConstraint(constraint.name(), false, false))
+ assert(updated.constraints.length === constraints.length - index - 1)
+ }
+ }
+
+ test("alterTable: drop non-existing constraint") {
+ val catalog = newCatalog()
+
+ val tableColumns = Array(
+ Column.create("id", IntegerType, false),
+ Column.create("data", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .withConstraints(constraints)
+ .build()
+ val table = catalog.createTable(testIdent, tableInfo)
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ catalog.alterTable(testIdent,
+ TableChange.dropConstraint("missing_constraint", false, false))
+ },
+ condition = "CONSTRAINT_DOES_NOT_EXIST",
+ parameters = Map("constraintName" -> "missing_constraint",
+ "tableName" -> table.name()))
+ }
+
+ test("alterTable: drop non-existing constraint if exists") {
+ val catalog = newCatalog()
+
+ val tableColumns = Array(
+ Column.create("id", IntegerType, false),
+ Column.create("data", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(tableColumns)
+ .withConstraints(constraints)
+ .build()
+ catalog.createTable(testIdent, tableInfo)
+ val updated = catalog.alterTable(testIdent,
+ TableChange.dropConstraint("missing_constraint", true, false))
+ assert(updated.constraints.length === constraints.length)
+ }
+
test("dropTable") {
val catalog = newCatalog()
assert(!catalog.tableExists(testIdent))
- catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ catalog.createTable(testIdent, tableInfo)
assert(catalog.tableExists(testIdent))
@@ -691,7 +992,12 @@ class CatalogSuite extends SparkFunSuite {
assert(!catalog.tableExists(testIdent))
assert(!catalog.tableExists(testIdentNew))
- catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ catalog.createTable(testIdent, tableInfo)
assert(catalog.tableExists(testIdent))
catalog.renameTable(testIdent, testIdentNew)
@@ -716,8 +1022,19 @@ class CatalogSuite extends SparkFunSuite {
assert(!catalog.tableExists(testIdent))
assert(!catalog.tableExists(testIdentNew))
- catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
- catalog.createTable(testIdentNew, columns, emptyTrans, emptyProps)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ catalog.createTable(testIdent, tableInfo)
+
+ val tableInfoNew = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ catalog.createTable(testIdentNew, tableInfoNew)
assert(catalog.tableExists(testIdent))
assert(catalog.tableExists(testIdentNew))
@@ -743,8 +1060,19 @@ class CatalogSuite extends SparkFunSuite {
val ident1 = Identifier.of(Array("ns1", "ns2"), "test_table_1")
val ident2 = Identifier.of(Array("ns1", "ns2"), "test_table_2")
- catalog.createTable(ident1, columns, emptyTrans, emptyProps)
- catalog.createTable(ident2, columns, emptyTrans, emptyProps)
+ val tableInfo1 = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ catalog.createTable(ident1, tableInfo1)
+
+ val tableInfo2 = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(emptyTrans)
+ .withProperties(emptyProps)
+ .build()
+ catalog.createTable(ident2, tableInfo2)
assert(catalog.listNamespaces === Array(Array("ns1")))
assert(catalog.listNamespaces(Array()) === Array(Array("ns1")))
@@ -939,13 +1267,12 @@ class CatalogSuite extends SparkFunSuite {
val partCatalog = new InMemoryPartitionTableCatalog
partCatalog.initialize("test", CaseInsensitiveStringMap.empty())
- val table = partCatalog.createTable(
- testIdent,
- Array(
- Column.create("col0", IntegerType),
- Column.create("part0", IntegerType)),
- Array[Transform](LogicalExpressions.identity(LogicalExpressions.parseReference("part0"))),
- util.Collections.emptyMap[String, String])
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(Array(Column.create("col0", IntegerType), Column.create("part0", IntegerType)))
+ .withPartitions(
+ Array[Transform](LogicalExpressions.identity(LogicalExpressions.parseReference("part0"))))
+ .withProperties(util.Collections.emptyMap[String, String]).build()
+ val table = partCatalog.createTable(testIdent, tableInfo)
val partTable = table.asInstanceOf[InMemoryPartitionTable]
val partIdent = InternalRow.apply(0)
val partIdent1 = InternalRow.apply(1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala
new file mode 100644
index 0000000000000..2d11bedb396fe
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.connector.catalog.constraints.Constraint
+import org.apache.spark.sql.connector.catalog.constraints.Constraint.ValidationStatus
+import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue, NamedReference}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.types.IntegerType
+
+class ConstraintSuite extends SparkFunSuite {
+
+ test("CHECK constraint toDDL") {
+ val con1 = Constraint.check("con1")
+ .predicateSql("id > 10")
+ .enforced(true)
+ .validationStatus(ValidationStatus.VALID)
+ .rely(true)
+ .build()
+ assert(con1.toDDL == "CONSTRAINT con1 CHECK (id > 10) ENFORCED VALID RELY")
+
+ val con2 = Constraint.check("con2")
+ .predicate(
+ new Predicate(
+ "=",
+ Array[Expression](
+ FieldReference(Seq("a", "b.c", "d")),
+ LiteralValue(1, IntegerType))))
+ .enforced(false)
+ .validationStatus(ValidationStatus.VALID)
+ .rely(true)
+ .build()
+ assert(con2.toDDL == "CONSTRAINT con2 CHECK (a.`b.c`.d = 1) NOT ENFORCED VALID RELY")
+
+ val con3 = Constraint.check("con3")
+ .predicateSql("a.b.c <=> 1")
+ .predicate(
+ new Predicate(
+ "<=>",
+ Array[Expression](
+ FieldReference(Seq("a", "b", "c")),
+ LiteralValue(1, IntegerType))))
+ .enforced(false)
+ .validationStatus(ValidationStatus.INVALID)
+ .rely(false)
+ .build()
+ assert(con3.toDDL == "CONSTRAINT con3 CHECK (a.b.c <=> 1) NOT ENFORCED INVALID NORELY")
+
+ val con4 = Constraint.check("con4").predicateSql("a = 1").build()
+ assert(con4.toDDL == "CONSTRAINT con4 CHECK (a = 1) ENFORCED UNVALIDATED NORELY")
+ }
+
+ test("UNIQUE constraint toDDL") {
+ val con1 = Constraint.unique(
+ "con1",
+ Array[NamedReference](FieldReference(Seq("a", "b", "c")), FieldReference(Seq("d"))))
+ .enforced(false)
+ .validationStatus(ValidationStatus.UNVALIDATED)
+ .rely(true)
+ .build()
+ assert(con1.toDDL == "CONSTRAINT con1 UNIQUE (a.b.c, d) NOT ENFORCED UNVALIDATED RELY")
+
+ val con2 = Constraint.unique(
+ "con2",
+ Array[NamedReference](FieldReference(Seq("a.b", "x", "y")), FieldReference(Seq("d"))))
+ .enforced(false)
+ .validationStatus(ValidationStatus.VALID)
+ .rely(true)
+ .build()
+ assert(con2.toDDL == "CONSTRAINT con2 UNIQUE (`a.b`.x.y, d) NOT ENFORCED VALID RELY")
+ }
+
+ test("PRIMARY KEY constraint toDDL") {
+ val pk1 = Constraint.primaryKey(
+ "pk1",
+ Array[NamedReference](FieldReference(Seq("a", "b", "c")), FieldReference(Seq("d"))))
+ .enforced(true)
+ .validationStatus(ValidationStatus.VALID)
+ .rely(true)
+ .build()
+ assert(pk1.toDDL == "CONSTRAINT pk1 PRIMARY KEY (a.b.c, d) ENFORCED VALID RELY")
+
+ val pk2 = Constraint.primaryKey(
+ "pk2",
+ Array[NamedReference](FieldReference(Seq("x.y", "z")), FieldReference(Seq("id"))))
+ .enforced(false)
+ .validationStatus(ValidationStatus.INVALID)
+ .rely(false)
+ .build()
+ assert(pk2.toDDL == "CONSTRAINT pk2 PRIMARY KEY (`x.y`.z, id) NOT ENFORCED INVALID NORELY")
+ }
+
+ test("FOREIGN KEY constraint toDDL") {
+ val fk1 = Constraint.foreignKey(
+ "fk1",
+ Array[NamedReference](FieldReference(Seq("col1")), FieldReference(Seq("col2"))),
+ Identifier.of(Array("schema"), "table"),
+ Array[NamedReference](FieldReference(Seq("ref_col1")), FieldReference(Seq("ref_col2"))))
+ .enforced(true)
+ .validationStatus(ValidationStatus.VALID)
+ .rely(true)
+ .build()
+ assert(fk1.toDDL == "CONSTRAINT fk1 FOREIGN KEY (col1, col2) " +
+ "REFERENCES schema.table (ref_col1, ref_col2) " +
+ "ENFORCED VALID RELY")
+
+ val fk2 = Constraint.foreignKey(
+ "fk2",
+ Array[NamedReference](FieldReference(Seq("x.y", "z"))),
+ Identifier.of(Array.empty[String], "other_table"),
+ Array[NamedReference](FieldReference(Seq("other_id"))))
+ .enforced(false)
+ .validationStatus(ValidationStatus.INVALID)
+ .rely(false)
+ .build()
+ assert(fk2.toDDL == "CONSTRAINT fk2 FOREIGN KEY (`x.y`.z) " +
+ "REFERENCES other_table (other_id) " +
+ "NOT ENFORCED INVALID NORELY")
+ }
+}
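// A minimal sketch of the DDL layout asserted above: the constraint body is followed by the
// enforcement flag, the validation status, and the reliance flag, in that order. The helper
// below is a hypothetical illustration of that suffix, not part of the Spark API.
object ConstraintDdlSketch {
  def optionsSuffix(enforced: Boolean, validationStatus: String, rely: Boolean): String = {
    val enforcedSql = if (enforced) "ENFORCED" else "NOT ENFORCED"
    val relySql = if (rely) "RELY" else "NORELY"
    s"$enforcedSql $validationStatus $relySql"
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the expected strings in the assertions above.
    assert(optionsSuffix(enforced = true, validationStatus = "VALID", rely = true) ==
      "ENFORCED VALID RELY")
    assert(optionsSuffix(enforced = false, validationStatus = "INVALID", rely = false) ==
      "NOT ENFORCED INVALID NORELY")
  }
}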
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 3ac8c3794b8ad..d6d397b94648d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -28,7 +28,7 @@ import com.google.common.base.Objects
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, MetadataStructFieldWithLogicalName}
-import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions._
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
@@ -141,7 +141,8 @@ abstract class InMemoryBaseTable(
schema: StructType,
row: InternalRow): (Any, DataType) = {
val index = schema.fieldIndex(fieldNames(0))
- val value = row.toSeq(schema).apply(index)
+ val field = schema(index)
+ val value = row.get(index, field.dataType)
if (fieldNames.length > 1) {
(value, schema(index).dataType) match {
case (row: InternalRow, nestedSchema: StructType) =>
@@ -400,18 +401,23 @@ abstract class InMemoryBaseTable(
val sizeInBytes = numRows * rowSizeInBytes
val numOfCols = tableSchema.fields.length
- val dataTypes = tableSchema.fields.map(_.dataType)
- val colValueSets = new Array[util.HashSet[Object]](numOfCols)
+ val colValueSets = new Array[util.HashSet[Any]](numOfCols)
val numOfNulls = new Array[Long](numOfCols)
for (i <- 0 until numOfCols) {
- colValueSets(i) = new util.HashSet[Object]
+ colValueSets(i) = new util.HashSet[Any]
}
inputPartitions.foreach(inputPartition =>
inputPartition.rows.foreach(row =>
for (i <- 0 until numOfCols) {
- colValueSets(i).add(row.get(i, dataTypes(i)))
- if (row.isNullAt(i)) {
+ val field = tableSchema(i)
+ val colValue = if (i < row.numFields) {
+ row.get(i, field.dataType)
+ } else {
+ ResolveDefaultColumns.getExistenceDefaultValue(field)
+ }
+ colValueSets(i).add(colValue)
+ if (colValue == null) {
numOfNulls(i) += 1
}
}
@@ -718,6 +724,11 @@ private class BufferedRowsReader(
schema: StructType,
row: InternalRow): Any = {
val index = schema.fieldIndex(field.name)
+
+ if (index >= row.numFields) {
+ return ResolveDefaultColumns.getExistenceDefaultValue(field)
+ }
+
field.dataType match {
case StructType(fields) =>
if (row.isNullAt(index)) {
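// A simplified, hypothetical sketch of the fallback introduced above: when a stored row was
// written before a column existed (so the row has fewer fields than the current schema), the
// reader substitutes that column's existence default instead of reading past the row's end.
// RowSketch and FieldSketch are illustrative stand-ins, not Spark's InternalRow/StructField.
object ExistenceDefaultSketch {
  final case class FieldSketch(name: String, existenceDefault: Any)
  final case class RowSketch(values: IndexedSeq[Any]) { def numFields: Int = values.length }

  def readColumn(row: RowSketch, index: Int, field: FieldSketch): Any =
    if (index >= row.numFields) field.existenceDefault // old row: fall back to the default
    else row.values(index)                             // value physically present in the row
}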
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala
index 437d7ffa63914..4f0588498ec4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala
@@ -30,11 +30,8 @@ class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog with Pro
protected val functions: util.Map[Identifier, UnboundFunction] =
new ConcurrentHashMap[Identifier, UnboundFunction]()
- override protected def allNamespaces: Seq[Seq[String]] = {
- (tables.keySet.asScala.map(_.namespace.toSeq) ++
- functions.keySet.asScala.map(_.namespace.toSeq) ++
- namespaces.keySet.asScala).toSeq.distinct
- }
+ override protected def allNamespaces: Seq[Seq[String]] =
+ (super.allNamespaces ++ functions.keySet.asScala.map(_.namespace.toSeq)).distinct
override def listFunctions(namespace: Array[String]): Array[Identifier] = {
if (namespace.isEmpty || namespaceExists(namespace)) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala
index 3b8020003aa4a..17f908370f76e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala
@@ -43,4 +43,8 @@ class InMemoryPartitionTableCatalog extends InMemoryTableCatalog {
namespaces.putIfAbsent(ident.namespace.toList, Map())
table
}
+
+ override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+ createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
index 98678289fa259..c822e27ceb58e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
@@ -21,6 +21,7 @@ import java.util
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NamedReference, SortDirection, SortOrder, Transform}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
@@ -35,8 +36,10 @@ class InMemoryRowLevelOperationTable(
name: String,
schema: StructType,
partitioning: Array[Transform],
- properties: util.Map[String, String])
- extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations {
+ properties: util.Map[String, String],
+ constraints: Array[Constraint] = Array.empty)
+ extends InMemoryTable(name, schema, partitioning, properties, constraints)
+ with SupportsRowLevelOperations {
private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name)
private final val INDEX_COLUMN_REF = FieldReference(IndexColumn.name)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
index deb200650bd52..29be5b19acd3c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
@@ -43,4 +43,37 @@ class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog {
namespaces.putIfAbsent(ident.namespace.toList, Map())
table
}
+
+ override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+ createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties)
+ }
+
+ override def alterTable(ident: Identifier, changes: TableChange*): Table = {
+ val table = loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable]
+ val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes)
+ val schema = CatalogV2Util.applySchemaChanges(
+ table.schema,
+ changes,
+ tableProvider = Some("in-memory"),
+ statementType = "ALTER TABLE")
+ val partitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes)
+ val constraints = CatalogV2Util.collectConstraintChanges(table, changes)
+
+ // fail if the last column in the schema was dropped
+ if (schema.fields.isEmpty) {
+ throw new IllegalArgumentException(s"Cannot drop all fields")
+ }
+
+ val newTable = new InMemoryRowLevelOperationTable(
+ name = table.name,
+ schema = schema,
+ partitioning = partitioning,
+ properties = properties,
+ constraints = constraints)
+ newTable.withData(table.data)
+
+ tables.put(ident, newTable)
+
+ newTable
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index c27b8fea059f7..50e2449623e5c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.catalog
import java.util
+import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{SortOrder, Transform}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite, WriteBuilder, WriterCommitMessage}
@@ -35,6 +36,7 @@ class InMemoryTable(
schema: StructType,
override val partitioning: Array[Transform],
override val properties: util.Map[String, String],
+ override val constraints: Array[Constraint] = Array.empty,
distribution: Distribution = Distributions.unspecified(),
ordering: Array[SortOrder] = Array.empty,
numPartitions: Option[Int] = None,
@@ -61,23 +63,31 @@ class InMemoryTable(
override def withData(
data: Array[BufferedRows],
- writeSchema: StructType): InMemoryTable = dataMap.synchronized {
- data.foreach(_.rows.foreach { row =>
- val key = getKey(row, writeSchema)
- dataMap += dataMap.get(key)
- .map { splits =>
- val newSplits = if (splits.last.rows.size >= numRowsPerSplit) {
- splits :+ new BufferedRows(key)
- } else {
- splits
+ writeSchema: StructType): InMemoryTable = {
+ dataMap.synchronized {
+ data.foreach(_.rows.foreach { row =>
+ val key = getKey(row, writeSchema)
+ dataMap += dataMap.get(key)
+ .map { splits =>
+ val newSplits = if (splits.last.rows.size >= numRowsPerSplit) {
+ splits :+ new BufferedRows(key)
+ } else {
+ splits
+ }
+ newSplits.last.withRow(row)
+ key -> newSplits
}
- newSplits.last.withRow(row)
- key -> newSplits
- }
- .getOrElse(key -> Seq(new BufferedRows(key).withRow(row)))
- addPartitionKey(key)
- })
- this
+ .getOrElse(key -> Seq(new BufferedRows(key).withRow(row)))
+ addPartitionKey(key)
+ })
+
+ if (data.exists(_.rows.exists(row => row.numFields == 1 &&
+ row.getInt(0) == InMemoryTable.uncommittableValue()))) {
+ throw new IllegalArgumentException(s"Test only mock write failure")
+ }
+
+ this
+ }
}
override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
@@ -166,6 +176,8 @@ object InMemoryTable {
}
}
+ def uncommittableValue(): Int = Int.MaxValue / 2
+
private def splitAnd(filter: Filter): Seq[Filter] = {
filter match {
case And(left, right) => splitAnd(left) ++ splitAnd(right)
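// A hedged sketch of how a test can hit the simulated write failure above. It assumes a
// SparkSession whose `testcat` catalog is backed by this in-memory implementation and an
// existing table `testcat.ns.tbl` with a single INT column; those names are illustrative only.
object UncommittableWriteSketch {
  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.connector.catalog.InMemoryTable

  def triggerMockWriteFailure(spark: SparkSession): Unit = {
    val poison = InMemoryTable.uncommittableValue() // Int.MaxValue / 2
    // The commit path calls withData, which rejects any single-field row holding the sentinel.
    spark.sql(s"INSERT INTO testcat.ns.tbl VALUES ($poison)")
  }
}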
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
index ae11cb9d69580..7d64cad2bb102 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
@@ -26,6 +26,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NonEmptyNamespaceException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{SortOrder, Transform}
@@ -90,14 +91,20 @@ class BasicInMemoryTableCatalog extends TableCatalog {
}
override def createTable(
- ident: Identifier,
- columns: Array[Column],
- partitions: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ ident: Identifier,
+ columns: Array[Column],
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
createTable(ident, columns, partitions, properties, Distributions.unspecified(),
Array.empty, None, None)
}
+ override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+ createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties(),
+ Distributions.unspecified(), Array.empty, None, None, tableInfo.constraints())
+ }
+
+ // scalastyle:off argcount
def createTable(
ident: Identifier,
columns: Array[Column],
@@ -107,8 +114,10 @@ class BasicInMemoryTableCatalog extends TableCatalog {
ordering: Array[SortOrder],
requiredNumPartitions: Option[Int],
advisoryPartitionSize: Option[Long],
+ constraints: Array[Constraint] = Array.empty,
distributionStrictlyRequired: Boolean = true,
numRowsPerSplit: Int = Int.MaxValue): Table = {
+ // scalastyle:on argcount
val schema = CatalogV2Util.v2ColumnsToStructType(columns)
if (tables.containsKey(ident)) {
throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
@@ -117,9 +126,9 @@ class BasicInMemoryTableCatalog extends TableCatalog {
InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
val tableName = s"$name.${ident.quoted}"
- val table = new InMemoryTable(tableName, schema, partitions, properties, distribution,
- ordering, requiredNumPartitions, advisoryPartitionSize, distributionStrictlyRequired,
- numRowsPerSplit)
+ val table = new InMemoryTable(tableName, schema, partitions, properties, constraints,
+ distribution, ordering, requiredNumPartitions, advisoryPartitionSize,
+ distributionStrictlyRequired, numRowsPerSplit)
tables.put(ident, table)
namespaces.putIfAbsent(ident.namespace.toList, Map())
table
@@ -128,15 +137,25 @@ class BasicInMemoryTableCatalog extends TableCatalog {
override def alterTable(ident: Identifier, changes: TableChange*): Table = {
val table = loadTable(ident).asInstanceOf[InMemoryTable]
val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes)
- val schema = CatalogV2Util.applySchemaChanges(table.schema, changes, None, "ALTER TABLE")
+ val schema = CatalogV2Util.applySchemaChanges(
+ table.schema,
+ changes,
+ tableProvider = Some("in-memory"),
+ statementType = "ALTER TABLE")
val finalPartitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes)
+ val constraints = CatalogV2Util.collectConstraintChanges(table, changes)
// fail if the last column in the schema was dropped
if (schema.fields.isEmpty) {
throw new IllegalArgumentException(s"Cannot drop all fields")
}
- val newTable = new InMemoryTable(table.name, schema, finalPartitioning, properties)
+ val newTable = new InMemoryTable(
+ name = table.name,
+ schema = schema,
+ partitioning = finalPartitioning,
+ properties = properties,
+ constraints = constraints)
.withData(table.data)
tables.put(ident, newTable)
@@ -174,6 +193,7 @@ class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamesp
override def capabilities: java.util.Set[TableCatalogCapability] = {
Set(
TableCatalogCapability.SUPPORT_COLUMN_DEFAULT_VALUE,
+ TableCatalogCapability.SUPPORT_TABLE_CONSTRAINT,
TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_GENERATED_COLUMNS,
TableCatalogCapability.SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS
).asJava
@@ -184,7 +204,12 @@ class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamesp
procedures.put(Identifier.of(Array("dummy"), "increment"), UnboundIncrement)
protected def allNamespaces: Seq[Seq[String]] = {
- (tables.keySet.asScala.map(_.namespace.toSeq) ++ namespaces.keySet.asScala).toSeq.distinct
+ (tables.keySet.asScala.map(_.namespace.toSeq)
+ ++ namespaces.keySet.asScala
+ ++ procedures.keySet.asScala
+ .filter(i => !i.namespace.sameElements(Array("dummy")))
+ .map(_.namespace.toSeq)
+ ).toSeq.distinct
}
override def namespaceExists(namespace: Array[String]): Boolean = {
@@ -268,6 +293,17 @@ class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamesp
procedure
}
+ override def listProcedures(namespace: Array[String]): Array[Identifier] = {
+ val result =
+ if (namespaceExists(namespace)) {
+ procedures.keySet.asScala
+ .filter(_.namespace.sameElements(namespace))
+ } else {
+ throw new NoSuchNamespaceException(namespace)
+ }
+ result.filter(!_.namespace.sameElements(Array("dummy"))).toArray
+ }
+
object UnboundIncrement extends UnboundProcedure {
override def name: String = "dummy_increment"
override def description: String = "test method to increment an in-memory counter"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala
index 7ec1cab304ade..861badd390798 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala
@@ -43,4 +43,8 @@ class InMemoryTableWithV2FilterCatalog extends InMemoryTableCatalog {
namespaces.putIfAbsent(ident.namespace.toList, Map())
table
}
+
+ override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+ createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
index 2a207901b83f5..ee2400cab35c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
@@ -31,40 +31,28 @@ class StagingInMemoryTableCatalog extends InMemoryTableCatalog with StagingTable
import InMemoryTableCatalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
- override def stageCreate(
- ident: Identifier,
- columns: Array[Column],
- partitions: Array[Transform],
- properties: util.Map[String, String]): StagedTable = {
- validateStagedTable(partitions, properties)
+ override def stageCreate(ident: Identifier, tableInfo: TableInfo): StagedTable = {
+ validateStagedTable(tableInfo.partitions, tableInfo.properties)
new TestStagedCreateTable(
ident,
new InMemoryTable(s"$name.${ident.quoted}",
- CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties))
+ tableInfo.schema(), tableInfo.partitions(), tableInfo.properties()))
}
- override def stageReplace(
- ident: Identifier,
- columns: Array[Column],
- partitions: Array[Transform],
- properties: util.Map[String, String]): StagedTable = {
- validateStagedTable(partitions, properties)
+ override def stageReplace(ident: Identifier, tableInfo: TableInfo): StagedTable = {
+ validateStagedTable(tableInfo.partitions, tableInfo.properties)
new TestStagedReplaceTable(
ident,
new InMemoryTable(s"$name.${ident.quoted}",
- CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties))
+ tableInfo.schema(), tableInfo.partitions(), tableInfo.properties()))
}
- override def stageCreateOrReplace(
- ident: Identifier,
- columns: Array[Column],
- partitions: Array[Transform],
- properties: util.Map[String, String]): StagedTable = {
- validateStagedTable(partitions, properties)
+ override def stageCreateOrReplace(ident: Identifier, tableInfo: TableInfo) : StagedTable = {
+ validateStagedTable(tableInfo.partitions, tableInfo.properties)
new TestStagedCreateOrReplaceTable(
ident,
new InMemoryTable(s"$name.${ident.quoted}",
- CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties))
+ tableInfo.schema(), tableInfo.partitions(), tableInfo.properties))
}
private def validateStagedTable(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 397241be76eb1..115d561cbe7b8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -22,7 +22,7 @@ import org.json4s.jackson.JsonMethods
import org.apache.spark.{SparkException, SparkFunSuite, SparkIllegalArgumentException}
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CollationFactory, StringConcat}
import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes}
@@ -380,6 +380,8 @@ class DataTypeSuite extends SparkFunSuite {
checkDefaultSize(VarcharType(10), 10)
yearMonthIntervalTypes.foreach(checkDefaultSize(_, 4))
dayTimeIntervalTypes.foreach(checkDefaultSize(_, 8))
+ checkDefaultSize(TimeType(TimeType.MIN_PRECISION), 8)
+ checkDefaultSize(TimeType(TimeType.MAX_PRECISION), 8)
def checkEqualsIgnoreCompatibleNullability(
from: DataType,
@@ -1371,4 +1373,45 @@ class DataTypeSuite extends SparkFunSuite {
}
assert(exception.getMessage.contains("The length of varchar type cannot be negative."))
}
+
+ test("precisions of the TIME data type") {
+ TimeType.MIN_PRECISION to TimeType.MAX_PRECISION foreach { p =>
+ assert(TimeType(p).sql == s"TIME($p)")
+ }
+
+ Seq(
+ Int.MinValue,
+ TimeType.MIN_PRECISION - 1,
+ TimeType.MAX_PRECISION + 1,
+ Int.MaxValue).foreach { p =>
+ checkError(
+ exception = intercept[SparkException] {
+ TimeType(p)
+ },
+ condition = "UNSUPPORTED_TIME_PRECISION",
+ parameters = Map("precision" -> p.toString)
+ )
+ }
+ }
+
+ test("Parse time(n) as TimeType(n)") {
+ 0 to 6 foreach { n =>
+ assert(DataType.fromJson(s"\"time($n)\"") == TimeType(n))
+ val expectedStructType = StructType(Seq(StructField("t", TimeType(n))))
+ assert(DataType.fromDDL(s"t time($n)") == expectedStructType)
+ }
+
+ checkError(
+ exception = intercept[SparkIllegalArgumentException] {
+ DataType.fromJson("\"time(9)\"")
+ },
+ condition = "INVALID_JSON_DATA_TYPE",
+ parameters = Map("invalidType" -> "time(9)"))
+ checkError(
+ exception = intercept[ParseException] {
+ DataType.fromDDL("t time(-1)")
+ },
+ condition = "PARSE_SYNTAX_ERROR",
+ parameters = Map("error" -> "'time'", "hint" -> ""))
+ }
}
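// A minimal sketch of the precision bounds exercised above: TIME(p) is valid for p between
// TimeType.MIN_PRECISION and TimeType.MAX_PRECISION (0 to 6 in these tests), and any other
// precision fails with UNSUPPORTED_TIME_PRECISION.
object TimePrecisionSketch {
  import org.apache.spark.sql.types.TimeType

  val sqlNames: Seq[String] =
    (TimeType.MIN_PRECISION to TimeType.MAX_PRECISION).map(p => TimeType(p).sql)
  // sqlNames == Seq("TIME(0)", "TIME(1)", "TIME(2)", "TIME(3)", "TIME(4)", "TIME(5)", "TIME(6)")
}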
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
index b2f3adac68e13..fc011647fc8da 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
@@ -69,6 +69,10 @@ object DataTypeTestUtils {
YearMonthIntervalType(YEAR),
YearMonthIntervalType(MONTH))
+ val timeTypes: Seq[TimeType] = Seq(
+ TimeType(TimeType.MIN_PRECISION),
+ TimeType(TimeType.MAX_PRECISION))
+
val unsafeRowMutableFieldTypes: Seq[DataType] = Seq(
NullType,
BooleanType,
@@ -97,7 +101,7 @@ object DataTypeTestUtils {
TimestampNTZType,
DateType,
StringType,
- BinaryType) ++ dayTimeIntervalTypes ++ yearMonthIntervalTypes
+ BinaryType) ++ dayTimeIntervalTypes ++ yearMonthIntervalTypes ++ timeTypes
/**
* All the types that we can use in a property check
@@ -113,7 +117,7 @@ object DataTypeTestUtils {
DateType,
StringType,
TimestampType,
- TimestampNTZType) ++ dayTimeIntervalTypes ++ yearMonthIntervalTypes
+ TimestampNTZType) ++ dayTimeIntervalTypes ++ yearMonthIntervalTypes ++ timeTypes
/**
* Instances of [[ArrayType]] for all [[AtomicType]]s. Arrays of these types may contain null.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 821d59796753b..4af2c8367819b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -382,7 +382,8 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
parameters = Map(
"expression" -> "'str'",
"sourceType" -> "\"STRING\"",
- "targetType" -> "\"DECIMAL(10,0)\""))
+ "targetType" -> "\"DECIMAL(10,0)\"",
+ "ansiConfig" -> "\"spark.sql.ansi.enabled\""))
}
test("SPARK-35841: Casting string to decimal type doesn't work " +
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
index ad0adf13643aa..5dd45d3d4496c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType._
import org.apache.spark.sql.types.StructType.fromDDL
import org.apache.spark.sql.types.YearMonthIntervalType._
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.UTF8String.LongWrapper
class StructTypeSuite extends SparkFunSuite with SQLHelper {
@@ -835,18 +836,33 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
test("SPARK-51119: Add fallback to process unresolved EXISTS_DEFAULT") {
val source = StructType(
Array(
- StructField("c1", VariantType, true,
- new MetadataBuilder()
- .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, "parse_json(null)")
- .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "parse_json(null)")
- .build()),
- StructField("c0", StringType, true,
- new MetadataBuilder()
- .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, "current_catalog()")
- .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "current_catalog()")
- .build())))
+ StructField("c0", VariantType, true,
+ new MetadataBuilder()
+ .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY,
+ "parse_json(null)")
+ .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY,
+ "parse_json(null)")
+ .build()),
+ StructField("c1", StringType, true,
+ new MetadataBuilder()
+ .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY,
+ "current_catalog()")
+ .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY,
+ "current_catalog()")
+ .build()),
+ StructField("c2", StringType, true,
+ new MetadataBuilder()
+ .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY,
+ "CAST(CURRENT_TIMESTAMP AS BIGINT)")
+ .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY,
+ "CAST(CURRENT_TIMESTAMP AS BIGINT)")
+ .build())))
val res = ResolveDefaultColumns.existenceDefaultValues(source)
assert(res(0) == null)
assert(res(1) == UTF8String.fromString("spark_catalog"))
+
+ val res2Wrapper = new LongWrapper
+ assert(res(2).asInstanceOf[UTF8String].toLong(res2Wrapper))
+ assert(res2Wrapper.value > 0)
}
}
diff --git a/sql/connect/client/jvm/pom.xml b/sql/connect/client/jvm/pom.xml
index 81e195fb5ccde..ee4f7b3483e61 100644
--- a/sql/connect/client/jvm/pom.xml
+++ b/sql/connect/client/jvm/pom.xml
@@ -138,12 +138,19 @@
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+        <configuration>
+          <argLine>-ea -Xmx4g -Xss4m -XX:MaxMetaspaceSize=2g -XX:ReservedCodeCacheSize=${CodeCacheSize} ${extraJavaTestArgs} -Darrow.memory.debug.allocator=true</argLine>
+        </configuration>
+      </plugin>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-shade-plugin</artifactId>
-
+ falsetrue
@@ -261,24 +268,6 @@
-      <plugin>
-        <groupId>org.codehaus.mojo</groupId>
-        <artifactId>build-helper-maven-plugin</artifactId>
-        <executions>
-          <execution>
-            <id>add-sources</id>
-            <phase>generate-sources</phase>
-            <goals>
-              <goal>add-source</goal>
-            </goals>
-            <configuration>
-              <sources>
-                <source>src/main/scala-${scala.binary.version}</source>
-              </sources>
-            </configuration>
-          </execution>
-        </executions>
-      </plugin>
-</project>
\ No newline at end of file
+</project>
diff --git a/sql/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/sql/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
index 907105e370c08..afb046af98594 100644
--- a/sql/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
+++ b/sql/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
@@ -28,6 +28,7 @@
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.RowFactory.create;
import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.connect.test.IntegrationTestUtils;
import org.apache.spark.sql.connect.test.SparkConnectServerUtils;
import org.apache.spark.sql.types.StructType;
@@ -39,14 +40,18 @@ public class JavaEncoderSuite implements Serializable {
@BeforeAll
public static void setup() {
+ Assumptions.assumeTrue(IntegrationTestUtils.isAssemblyJarsDirExists(),
+ "Skipping all tests because assembly jars directory does not exist.");
spark = SparkConnectServerUtils.createSparkSession();
}
@AfterAll
public static void tearDown() {
- spark.stop();
- spark = null;
- SparkConnectServerUtils.stop();
+ if (spark != null) {
+ spark.stop();
+ spark = null;
+ SparkConnectServerUtils.stop();
+ }
}
private static BigDecimal bigDec(long unscaled, int scale) {
diff --git a/sql/connect/client/jvm/src/test/resources/log4j2.properties b/sql/connect/client/jvm/src/test/resources/log4j2.properties
index 550fd261b6fb5..47b6e39eb020f 100644
--- a/sql/connect/client/jvm/src/test/resources/log4j2.properties
+++ b/sql/connect/client/jvm/src/test/resources/log4j2.properties
@@ -37,3 +37,6 @@ appender.console.layout.pattern = %d{HH:mm:ss.SSS} %p %c: %maxLen{%m}{512}%n%ex{
# Ignore messages below warning level from Jetty, because it's a bit verbose
logger.jetty.name = org.sparkproject.jetty
logger.jetty.level = warn
+
+logger.allocator.name = org.apache.arrow.memory.BaseAllocator
+logger.allocator.level = trace
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index ba77879a5a800..9e93cd4442d5e 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -45,10 +45,13 @@ class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession {
row((null, 5.0)),
row((6, null))).toDF("c", "d")
+ lazy val t = r.filter($"c".isNotNull && $"d".isNotNull)
+
override def beforeAll(): Unit = {
super.beforeAll()
l.createOrReplaceTempView("l")
r.createOrReplaceTempView("r")
+ t.createOrReplaceTempView("t")
}
test("noop outer()") {
@@ -318,6 +321,128 @@ class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession {
sql("select a, (select sum(d) from r where c = a) from l"))
}
+ test("IN predicate subquery") {
+ checkAnswer(
+ spark.table("l").where($"l.a".isin(spark.table("r").select($"c"))),
+ sql("select * from l where l.a in (select c from r)"))
+
+ checkAnswer(
+ spark
+ .table("l")
+ .where($"l.a".isin(spark.table("r").where($"l.b".outer() < $"r.d").select($"c"))),
+ sql("select * from l where l.a in (select c from r where l.b < r.d)"))
+
+ checkAnswer(
+ spark
+ .table("l")
+ .where($"l.a".isin(spark.table("r").select("c")) && $"l.a" > 2 && $"l.b".isNotNull),
+ sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b is not null"))
+ }
+
+ test("IN predicate subquery with struct") {
+ withTempView("ll", "rr") {
+ spark.table("l").select($"*", struct("a", "b").alias("sab")).createOrReplaceTempView("ll")
+ spark
+ .table("r")
+ .select($"*", struct($"c".as("a"), $"d".as("b")).alias("scd"))
+ .createOrReplaceTempView("rr")
+
+ for ((col, values) <- Seq(
+ ($"sab", "sab"),
+ (struct(struct($"a", $"b")), "struct(struct(a, b))"));
+ (df, query) <- Seq(
+ (spark.table("rr").select($"scd"), "select scd from rr"),
+ (
+ spark.table("rr").select(struct($"c".as("a"), $"d".as("b"))),
+ "select struct(c as a, d as b) from rr"),
+ (spark.table("rr").select(struct($"c", $"d")), "select struct(c, d) from rr"))) {
+ checkAnswer(
+ spark.table("ll").where(col.isin(df)).select($"a", $"b"),
+ sql(s"select a, b from ll where $values in ($query)"))
+ }
+ }
+ }
+
+ test("NOT IN predicate subquery") {
+ checkAnswer(
+ spark.table("l").where(!$"a".isin(spark.table("r").select($"c"))),
+ sql("select * from l where a not in (select c from r)"))
+
+ checkAnswer(
+ spark.table("l").where(!$"a".isin(spark.table("r").where($"c".isNotNull).select($"c"))),
+ sql("select * from l where a not in (select c from r where c is not null)"))
+
+ checkAnswer(
+ spark
+ .table("l")
+ .where(!struct($"a", $"b").isin(spark.table("t").select($"c", $"d")) && $"a" < lit(4)),
+ sql("select * from l where (a, b) not in (select c, d from t) and a < 4"))
+
+ // Empty sub-query
+ checkAnswer(
+ spark
+ .table("l")
+ .where(
+ !struct($"a", $"b").isin(spark.table("r").where($"c" > lit(10)).select($"c", $"d"))),
+ sql("select * from l where (a, b) not in (select c, d from r where c > 10)"))
+ }
+
+ test("IN predicate subquery within OR") {
+ checkAnswer(
+ spark
+ .table("l")
+ .where($"l.a".isin(spark.table("r").select("c"))
+ || $"l.a".isin(spark.table("r").where($"l.b".outer() < $"r.d").select($"c"))),
+ sql(
+ "select * from l where l.a in (select c from r)" +
+ " or l.a in (select c from r where l.b < r.d)"))
+
+ checkAnswer(
+ spark
+ .table("l")
+ .where(!$"a".isin(spark.table("r").select("c"))
+ || !$"a".isin(spark.table("r").where($"c".isNotNull).select($"c"))),
+ sql(
+ "select * from l where a not in (select c from r)" +
+ " or a not in (select c from r where c is not null)"))
+ }
+
+ test("complex IN predicate subquery") {
+ checkAnswer(
+ spark.table("l").where(!struct($"a", $"b").isin(spark.table("r").select($"c", $"d"))),
+ sql("select * from l where (a, b) not in (select c, d from r)"))
+
+ checkAnswer(
+ spark
+ .table("l")
+ .where(!struct($"a", $"b").isin(spark.table("t").select($"c", $"d"))
+ && ($"a" + $"b").isNotNull),
+ sql("select * from l where (a, b) not in (select c, d from t) and (a + b) is not null"))
+ }
+
+ test("same column in subquery and outer table") {
+ checkAnswer(
+ spark
+ .table("l")
+ .as("l1")
+ .where(
+ $"a".isin(
+ spark
+ .table("l")
+ .where($"a" < lit(3))
+ .groupBy($"a")
+ .agg(Map.empty[String, String])))
+ .select($"a"),
+ sql("select a from l l1 where a in (select a from l where a < 3 group by a)"))
+ }
+
+ test("col IN (NULL)") {
+ checkAnswer(spark.table("l").where($"a".isin(null)), sql("SELECT * FROM l WHERE a IN (NULL)"))
+ checkAnswer(
+ spark.table("l").where(!$"a".isin(null)),
+ sql("SELECT * FROM l WHERE a NOT IN (NULL)"))
+ }
+
private def table1() = {
sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
spark.table("t1")
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 816050f441781..c07b624e8f8fe 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -176,7 +176,7 @@ class ReplE2ESuite extends ConnectFunSuite with RemoteSparkSession with BeforeAn
.get(s"$sparkHome/sql/connect/client/jvm/src/test/resources/TestHelloV2_$scalaVersion.jar")
.toFile
- assert(testJar.exists(), "Missing TestHelloV2 jar!")
+ assume(testJar.exists(), "Missing TestHelloV2 jar!")
val input = s"""
|import java.nio.file.Paths
|def classLoadingTest(x: Int): Int = {
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index 19263cdbed6d5..566584f7e7a8f 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -1058,7 +1058,12 @@ class ClientE2ETestSuite
private def checkSameResult[E](expected: scala.collection.Seq[E], dataset: Dataset[E]): Unit = {
dataset.withResult { result =>
- assert(expected === result.iterator.toBuffer)
+ val iter = result.iterator
+ try {
+ assert(expected === iter.toBuffer)
+ } finally {
+ iter.close()
+ }
}
}
@@ -1233,9 +1238,9 @@ class ClientE2ETestSuite
.filter("id > 5 and id < 9")
df.withResult { result =>
+ // build and verify the destructive iterator
+ val iterator = result.destructiveIterator
try {
- // build and verify the destructive iterator
- val iterator = result.destructiveIterator
// resultMap Map is empty before traversing the result iterator
assertResultsMapEmpty(result)
val buffer = mutable.Set.empty[Long]
@@ -1250,7 +1255,7 @@ class ClientE2ETestSuite
val expectedResult = Set(6L, 7L, 8L)
assert(buffer.size === 3 && expectedResult == buffer)
} finally {
- result.close()
+ iterator.close()
}
}
}
@@ -1565,12 +1570,12 @@ class ClientE2ETestSuite
val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema))
val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema))
- val obMetrics = observedDf.collectResult().getObservedMetrics
- assert(df.collectResult().getObservedMetrics === Map.empty)
- assert(observedDf.collectResult().getObservedMetrics === ob1Metrics)
+ val obMetrics = observedDf.withResult(_.getObservedMetrics)
+ assert(df.withResult(_.getObservedMetrics) === Map.empty)
+ assert(observedDf.withResult(_.getObservedMetrics) === ob1Metrics)
assert(obMetrics.map(_._2.schema) === Seq(ob1Schema))
- val obObMetrics = observedObservedDf.collectResult().getObservedMetrics
+ val obObMetrics = observedObservedDf.withResult(_.getObservedMetrics)
assert(obObMetrics === ob1Metrics ++ ob2Metrics)
assert(obObMetrics.map(_._2.schema).exists(_.equals(ob1Schema)))
assert(obObMetrics.map(_._2.schema).exists(_.equals(ob2Schema)))
@@ -1579,7 +1584,7 @@ class ClientE2ETestSuite
for (collectFunc <- Seq(
("collect", (df: DataFrame) => df.collect()),
("collectAsList", (df: DataFrame) => df.collectAsList()),
- ("collectResult", (df: DataFrame) => df.collectResult().length),
+ ("collectResult", (df: DataFrame) => df.withResult(_.length)),
("write", (df: DataFrame) => df.write.format("noop").mode("append").save())))
test(
"Observation.get is blocked until the query is finished, " +
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala
index e28d8587a4191..1d022489b701b 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala
@@ -308,34 +308,36 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
}
test("progress is available for the spark result") {
- val result = spark
+ spark
.range(10000)
.repartition(1000)
- .collectResult()
- assert(result.length == 10000)
- assert(result.progress.stages.map(_.numTasks).sum > 100)
- assert(result.progress.stages.map(_.completedTasks).sum > 100)
+ .withResult { result =>
+ assert(result.length == 10000)
+ assert(result.progress.stages.map(_.numTasks).sum > 100)
+ assert(result.progress.stages.map(_.completedTasks).sum > 100)
+ }
}
test("interrupt operation") {
val session = spark
import session.implicits._
- val result = spark
+ spark
.range(10)
.map(n => {
Thread.sleep(5000); n
})
- .collectResult()
- // cancel
- val operationId = result.operationId
- val canceledId = spark.interruptOperation(operationId)
- assert(canceledId == Seq(operationId))
- // and check that it got canceled
- val e = intercept[SparkException] {
- result.toArray
- }
- assert(e.getMessage contains "OPERATION_CANCELED")
+ .withResult { result =>
+ // cancel
+ val operationId = result.operationId
+ val canceledId = spark.interruptOperation(operationId)
+ assert(canceledId == Seq(operationId))
+ // and check that it got canceled
+ val e = intercept[SparkException] {
+ result.toArray
+ }
+ assert(e.getMessage contains "OPERATION_CANCELED")
+ }
}
test("option propagation") {
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UDFClassLoadingE2ESuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UDFClassLoadingE2ESuite.scala
index b50442de31f04..42fc0ccfed721 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UDFClassLoadingE2ESuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/UDFClassLoadingE2ESuite.scala
@@ -37,8 +37,8 @@ class UDFClassLoadingE2ESuite extends ConnectFunSuite with RemoteSparkSession {
// See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created.
private val udfByteArray: Array[Byte] =
Files.readAllBytes(Paths.get(s"src/test/resources/udf$scalaVersion"))
- private val udfJar =
- new File(s"src/test/resources/udf$scalaVersion.jar").toURI.toURL
+ private val udfJarFile = new File(s"src/test/resources/udf$scalaVersion.jar")
+ private lazy val udfJar = udfJarFile.toURI.toURL
private def registerUdf(session: SparkSession): Unit = {
val builder = proto.CommonInlineUserDefinedFunction
@@ -55,6 +55,7 @@ class UDFClassLoadingE2ESuite extends ConnectFunSuite with RemoteSparkSession {
}
test("update class loader after stubbing: new session") {
+ assume(udfJarFile.exists)
// Session1 should stub the missing class, but fail to call methods on it
val session1 = spark.newSession()
@@ -71,6 +72,7 @@ class UDFClassLoadingE2ESuite extends ConnectFunSuite with RemoteSparkSession {
}
test("update class loader after stubbing: same session") {
+ assume(udfJarFile.exists)
// Session should stub the missing class, but fail to call methods on it
val session = spark.newSession()
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
index fb35812233562..dcf3b91fece2a 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
@@ -118,6 +118,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
private def singleChunkArtifactTest(path: String): Unit = {
test(s"Single Chunk Artifact - $path") {
val artifactPath = artifactFilePath.resolve(path)
+ assume(artifactPath.toFile.exists)
artifactManager.addArtifact(artifactPath.toString)
val receivedRequests = service.getAndClearLatestAddArtifactRequests()
@@ -179,6 +180,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
test("Chunked Artifact - junitLargeJar.jar") {
val artifactPath = artifactFilePath.resolve("junitLargeJar.jar")
+ assume(artifactPath.toFile.exists)
artifactManager.addArtifact(artifactPath.toString)
// Expected chunks = roundUp( file_size / chunk_size) = 12
// File size of `junitLargeJar.jar` is 384581 bytes.
@@ -197,8 +199,12 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
}
test("Batched SingleChunkArtifacts") {
- val file1 = artifactFilePath.resolve("smallClassFile.class").toUri
- val file2 = artifactFilePath.resolve("smallJar.jar").toUri
+ val path1 = artifactFilePath.resolve("smallClassFile.class")
+ assume(path1.toFile.exists)
+ val file1 = path1.toUri
+ val path2 = artifactFilePath.resolve("smallJar.jar")
+ assume(path2.toFile.exists)
+ val file2 = path2.toUri
artifactManager.addArtifacts(Seq(file1, file2))
val receivedRequests = service.getAndClearLatestAddArtifactRequests()
// Single request containing 2 artifacts.
@@ -219,10 +225,18 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
}
test("Mix of SingleChunkArtifact and chunked artifact") {
- val file1 = artifactFilePath.resolve("smallClassFile.class").toUri
- val file2 = artifactFilePath.resolve("junitLargeJar.jar").toUri
- val file3 = artifactFilePath.resolve("smallClassFileDup.class").toUri
- val file4 = artifactFilePath.resolve("smallJar.jar").toUri
+ val path1 = artifactFilePath.resolve("smallClassFile.class")
+ assume(path1.toFile.exists)
+ val file1 = path1.toUri
+ val path2 = artifactFilePath.resolve("junitLargeJar.jar")
+ assume(path2.toFile.exists)
+ val file2 = path2.toUri
+ val path3 = artifactFilePath.resolve("smallClassFileDup.class")
+ assume(path3.toFile.exists)
+ val file3 = path3.toUri
+ val path4 = artifactFilePath.resolve("smallJar.jar")
+ assume(path4.toFile.exists)
+ val file4 = path4.toUri
artifactManager.addArtifacts(Seq(file1, file2, file3, file4))
val receivedRequests = service.getAndClearLatestAddArtifactRequests()
// There are a total of 14 requests.
@@ -290,6 +304,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
test("artifact with custom target") {
val artifactPath = artifactFilePath.resolve("smallClassFile.class")
+ assume(artifactPath.toFile.exists)
val target = "sub/package/smallClassFile.class"
artifactManager.addArtifact(artifactPath.toString, target)
val receivedRequests = service.getAndClearLatestAddArtifactRequests()
@@ -310,6 +325,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
test("in-memory artifact with custom target") {
val artifactPath = artifactFilePath.resolve("smallClassFile.class")
+ assume(artifactPath.toFile.exists)
val artifactBytes = Files.readAllBytes(artifactPath)
val target = "sub/package/smallClassFile.class"
artifactManager.addArtifact(artifactBytes, target)
@@ -333,6 +349,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
"When both source and target paths are given, extension conditions are checked " +
"on target path") {
val artifactPath = artifactFilePath.resolve("smallClassFile.class")
+ assume(artifactPath.toFile.exists)
assertThrows[UnsupportedOperationException] {
artifactManager.addArtifact(artifactPath.toString, "dummy.extension")
}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala
index 2f8332878bbf5..92cd1acd45d40 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala
@@ -28,15 +28,15 @@ class ClassFinderSuite extends ConnectFunSuite {
private val classResourcePath = commonResourcePath.resolve("artifact-tests")
test("REPLClassDirMonitor functionality test") {
+ val requiredClasses = Seq("Hello.class", "smallClassFile.class", "smallClassFileDup.class")
+ requiredClasses.foreach(className =>
+ assume(classResourcePath.resolve(className).toFile.exists))
val copyDir = SparkFileUtils.createTempDir().toPath
FileUtils.copyDirectory(classResourcePath.toFile, copyDir.toFile)
val monitor = new REPLClassDirMonitor(copyDir.toAbsolutePath.toString)
def checkClasses(monitor: REPLClassDirMonitor, additionalClasses: Seq[String] = Nil): Unit = {
- val expectedClassFiles = (Seq(
- "Hello.class",
- "smallClassFile.class",
- "smallClassFileDup.class") ++ additionalClasses).map(name => Paths.get(name))
+ val expectedClassFiles = (requiredClasses ++ additionalClasses).map(name => Paths.get(name))
val foundArtifacts = monitor.findClasses().toSeq
assert(expectedClassFiles.forall { classPath =>
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
index a3c0220665324..ad0f880e8f57a 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
@@ -99,7 +99,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
assert(!builder.sslEnabled)
assert(builder.token.isEmpty)
assert(builder.userId.contains("Q12"))
- assert(builder.userName.isEmpty)
+ assert(builder.userName.contains(System.getProperty("user.name", null)))
assert(builder.options.isEmpty)
}
{
@@ -116,7 +116,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
assert(builder.userAgent.contains("_SPARK_CONNECT_SCALA"))
assert(builder.sslEnabled)
assert(builder.token.isEmpty)
- assert(builder.userId.isEmpty)
+ assert(builder.userId.contains(System.getProperty("user.name", null)))
assert(builder.userName.contains("Nico"))
assert(builder.options === Map(("mode", "turbo"), ("cluster", "mycl")))
}
@@ -127,8 +127,8 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
assert(builder.userAgent.contains("_SPARK_CONNECT_SCALA"))
assert(!builder.sslEnabled)
assert(builder.token.contains("thisismysecret"))
- assert(builder.userId.isEmpty)
- assert(builder.userName.isEmpty)
+ assert(builder.userId.contains(System.getProperty("user.name", null)))
+ assert(builder.userName.contains(System.getProperty("user.name", null)))
assert(builder.options.isEmpty)
}
}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index f6d4a3b587a91..9bb8f5889d330 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -70,6 +70,11 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
}
}
+ test("SPARK-51391: Use 'user.name' by default") {
+ client = SparkConnectClient.builder().build()
+ assert(client.userId == System.getProperty("user.name"))
+ }
+
test("Placeholder test: Create SparkConnectClient") {
client = SparkConnectClient.builder().userId("abc123").build()
assert(client.userId == "abc123")
@@ -482,7 +487,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
val session = SparkSession.builder().client(client).create()
val artifactFilePath = commonResourcePath.resolve("artifact-tests")
- session.addArtifact(artifactFilePath.resolve("smallClassFile.class").toString)
+ val path = artifactFilePath.resolve("smallClassFile.class")
+ assume(path.toFile.exists)
+ session.addArtifact(path.toString)
}
private def buildPlan(query: String): proto.Plan = {
@@ -630,6 +637,26 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
observer.onNext(proto.AddArtifactsRequest.newBuilder().build())
observer.onCompleted()
}
+
+ test("client can set a custom operation id for ExecutePlan requests") {
+ startDummyServer(0)
+ client = SparkConnectClient
+ .builder()
+ .connectionString(s"sc://localhost:${server.getPort}")
+ .enableReattachableExecute()
+ .build()
+
+ val plan = buildPlan("select * from range(10000000)")
+ val dummyUUID = "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
+ val iter = client.execute(plan, operationId = Some(dummyUUID))
+ val reattachableIter =
+ ExecutePlanResponseReattachableIterator.fromIterator(iter)
+ assert(reattachableIter.operationId == dummyUUID)
+ while (reattachableIter.hasNext) {
+ val resp = reattachableIter.next()
+ assert(resp.getOperationId == dummyUUID)
+ }
+ }
}
class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
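// A hedged sketch of the custom operation id exercised above. It assumes a connected
// SparkConnectClient and a proto.Plan, as in the suite: choosing the id up front lets the
// caller log or interrupt the operation before the first response arrives, and every
// ExecutePlanResponse is expected to echo the same id.
object CustomOperationIdSketch {
  import java.util.UUID
  import org.apache.spark.connect.proto
  import org.apache.spark.sql.connect.client.SparkConnectClient

  def executeWithKnownId(client: SparkConnectClient, plan: proto.Plan): String = {
    val operationId = UUID.randomUUID().toString
    val responses = client.execute(plan, operationId = Some(operationId))
    responses.foreach(resp => assert(resp.getOperationId == operationId))
    operationId
  }
}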
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 58e19389cae2e..75816a835aaa7 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -99,40 +99,49 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
val serializerAllocator = newAllocator("serialization")
val deserializerAllocator = newAllocator("deserialization")
- val arrowIterator = ArrowSerializer.serialize(
- input = iterator,
- enc = inputEncoder,
- allocator = serializerAllocator,
- maxRecordsPerBatch = maxRecordsPerBatch,
- maxBatchSize = maxBatchSize,
- batchSizeCheckInterval = batchSizeCheckInterval,
- timeZoneId = "UTC",
- largeVarTypes = false)
+ try {
+ val arrowIterator = ArrowSerializer.serialize(
+ input = iterator,
+ enc = inputEncoder,
+ allocator = serializerAllocator,
+ maxRecordsPerBatch = maxRecordsPerBatch,
+ maxBatchSize = maxBatchSize,
+ batchSizeCheckInterval = batchSizeCheckInterval,
+ timeZoneId = "UTC",
+ largeVarTypes = false)
- val inspectedIterator = if (inspectBatch != null) {
- arrowIterator.map { batch =>
- inspectBatch(batch)
- batch
+ val inspectedIterator = if (inspectBatch != null) {
+ arrowIterator.map { batch =>
+ inspectBatch(batch)
+ batch
+ }
+ } else {
+ arrowIterator
}
- } else {
- arrowIterator
- }
- val resultIterator =
- ArrowDeserializers.deserializeFromArrow(
- inspectedIterator,
- outputEncoder,
- deserializerAllocator,
- timeZoneId = "UTC")
- new CloseableIterator[O] {
- override def close(): Unit = {
- arrowIterator.close()
- resultIterator.close()
+ val resultIterator =
+ ArrowDeserializers.deserializeFromArrow(
+ inspectedIterator,
+ outputEncoder,
+ deserializerAllocator,
+ timeZoneId = "UTC")
+ new CloseableIterator[O] {
+ override def close(): Unit = {
+ arrowIterator.close()
+ resultIterator.close()
+ serializerAllocator.close()
+ deserializerAllocator.close()
+ }
+
+ override def hasNext: Boolean = resultIterator.hasNext
+
+ override def next(): O = resultIterator.next()
+ }
+ } catch {
+ case e: Throwable =>
serializerAllocator.close()
deserializerAllocator.close()
- }
- override def hasNext: Boolean = resultIterator.hasNext
- override def next(): O = resultIterator.next()
+ throw e
}
}
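// A generic sketch of the cleanup discipline applied above, with Resource standing in for the
// Arrow BufferAllocator: if wiring up the pipeline throws, close the resources immediately;
// otherwise ownership passes to the value the builder returns (here, its own close()).
object CleanupOnFailureSketch {
  def buildOrClose[R <: AutoCloseable, T](resources: Seq[R])(build: Seq[R] => T): T = {
    try {
      build(resources) // success: the returned value now owns and must close `resources`
    } catch {
      case e: Throwable =>
        resources.foreach(r => try r.close() catch { case _: Throwable => () })
        throw e
    }
  }
}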
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
index d5f38231d3450..580b8e1114f94 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
@@ -103,7 +103,8 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
lastProgress.stateOperators.head.customMetrics.keySet().asScala == Set(
"loadedMapCacheHitCount",
"loadedMapCacheMissCount",
- "stateOnCurrentVersionSizeBytes"))
+ "stateOnCurrentVersionSizeBytes",
+ "SnapshotLastUploaded.partition_0_default"))
assert(lastProgress.sources.nonEmpty)
assert(lastProgress.sink.description == "MemorySink")
assert(lastProgress.observedMetrics.isEmpty)
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala
new file mode 100644
index 0000000000000..310b50dac1cc3
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala
@@ -0,0 +1,522 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.streaming
+
+import java.io.{BufferedWriter, File, FileWriter}
+import java.nio.file.Paths
+import java.sql.Timestamp
+
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.timeout
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, Dataset, Encoders, Row}
+import org.apache.spark.sql.connect.SparkSession
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.{ListState, MapState, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode, TimerValues, TTLConfig, ValueState}
+import org.apache.spark.sql.types._
+
+case class InputRowForConnectTest(key: String, value: String)
+case class OutputRowForConnectTest(key: String, value: String)
+case class StateRowForConnectTest(count: Long)
+
+// A basic stateful processor that returns the number of occurrences of each key
+class BasicCountStatefulProcessor
+ extends StatefulProcessor[String, InputRowForConnectTest, OutputRowForConnectTest]
+ with Logging {
+ @transient protected var _countState: ValueState[StateRowForConnectTest] = _
+
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+ _countState = getHandle.getValueState[StateRowForConnectTest](
+ "countState",
+ Encoders.product[StateRowForConnectTest],
+ TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[InputRowForConnectTest],
+ timerValues: TimerValues): Iterator[OutputRowForConnectTest] = {
+ val count = inputRows.toSeq.length + {
+ if (_countState.exists()) {
+ _countState.get().count
+ } else {
+ 0L
+ }
+ }
+ _countState.update(StateRowForConnectTest(count))
+ Iterator(OutputRowForConnectTest(key, count.toString))
+ }
+}
+
+// A stateful processor with initial state that returns the number of occurrences of each key
+class TestInitialStatefulProcessor
+ extends StatefulProcessorWithInitialState[
+ String,
+ (String, String),
+ (String, String),
+ (String, String, String)]
+ with Logging {
+ @transient protected var _countState: ValueState[Long] = _
+
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+ _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong, TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[(String, String)],
+ timerValues: TimerValues): Iterator[(String, String)] = {
+ val count = inputRows.toSeq.length + {
+ if (_countState.exists()) {
+ _countState.get()
+ } else {
+ 0L
+ }
+ }
+ _countState.update(count)
+ Iterator((key, count.toString))
+ }
+
+ override def handleInitialState(
+ key: String,
+ initialState: (String, String, String),
+ timerValues: TimerValues): Unit = {
+ val count = 1 + {
+ if (_countState.exists()) {
+ _countState.get()
+ } else {
+ 0L
+ }
+ }
+ _countState.update(count)
+ }
+}
+
+case class OutputEventTimeRow(key: String, outputTimestamp: Timestamp)
+
+// A stateful processor that returns the timestamp of the first item from the input rows
+class ChainingOfOpsStatefulProcessor
+ extends StatefulProcessor[String, (String, Timestamp), OutputEventTimeRow] {
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {}
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[(String, Timestamp)],
+ timerValues: TimerValues): Iterator[OutputEventTimeRow] = {
+ val timestamp = inputRows.next()._2
+ Iterator(OutputEventTimeRow(key, timestamp))
+ }
+}
+
+// A basic stateful processor containing composite state variables and TTL
+class TTLTestStatefulProcessor
+ extends StatefulProcessor[String, (String, String), (String, String)] {
+ import java.time.Duration
+
+ @transient protected var countState: ValueState[Int] = _
+ @transient protected var ttlCountState: ValueState[Int] = _
+ @transient protected var ttlListState: ListState[Int] = _
+ @transient protected var ttlMapState: MapState[String, Int] = _
+
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+ countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt, TTLConfig.NONE)
+ ttlCountState = getHandle
+ .getValueState[Int]("ttlCountState", Encoders.scalaInt, TTLConfig(Duration.ofMillis(1000)))
+ ttlListState = getHandle
+ .getListState[Int]("ttlListState", Encoders.scalaInt, TTLConfig(Duration.ofMillis(1000)))
+ ttlMapState = getHandle.getMapState[String, Int](
+ "ttlMapState",
+ Encoders.STRING,
+ Encoders.scalaInt,
+ TTLConfig(Duration.ofMillis(1000)))
+ }
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[(String, String)],
+ timerValues: TimerValues): Iterator[(String, String)] = {
+ val numOfInputRows = inputRows.toSeq.length
+ var count = numOfInputRows
+ var ttlCount = numOfInputRows
+ var ttlListStateCount = numOfInputRows
+ var ttlMapStateCount = numOfInputRows
+
+ if (countState.exists()) {
+ count += countState.get()
+ }
+ if (ttlCountState.exists()) {
+ ttlCount += ttlCountState.get()
+ }
+ if (ttlListState.exists()) {
+ for (value <- ttlListState.get()) {
+ ttlListStateCount += value
+ }
+ }
+ if (ttlMapState.exists()) {
+ ttlMapStateCount = ttlMapState.getValue(key)
+ }
+ countState.update(count)
+ if (key != "0") {
+ ttlCountState.update(ttlCount)
+ ttlListState.put(Array(ttlListStateCount, ttlListStateCount))
+ ttlMapState.updateValue(key, ttlMapStateCount)
+ }
+ val output = List(
+ (s"count-$key", count.toString),
+ (s"ttlCount-$key", ttlCount.toString),
+ (s"ttlListState-$key", ttlListStateCount.toString),
+ (s"ttlMapState-$key", ttlMapStateCount.toString))
+ output.iterator
+ }
+}
+
+class TransformWithStateConnectSuite
+ extends QueryTest
+ with RemoteSparkSession
+ with Logging
+ with BeforeAndAfterEach {
+ val testData: Seq[(String, String)] = Seq(("a", "1"), ("b", "1"), ("a", "2"))
+ val twsAdditionalSQLConf = Seq(
+ "spark.sql.streaming.stateStore.providerClass" ->
+ "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
+ "spark.sql.shuffle.partitions" -> "5",
+ "spark.sql.session.timeZone" -> "UTC",
+ "spark.sql.streaming.noDataMicroBatches.enabled" -> "false")
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.sql("DROP TABLE IF EXISTS my_sink")
+ }
+
+ override protected def afterEach(): Unit = {
+ try {
+ spark.sql("DROP TABLE IF EXISTS my_sink")
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ test("transformWithState - streaming with state variable, case class type") {
+ withSQLConf(twsAdditionalSQLConf: _*) {
+ val session: SparkSession = spark
+ import session.implicits._
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ testData
+ .toDS()
+ .toDF("key", "value")
+ .repartition(3)
+ .write
+ .parquet(path)
+
+ val testSchema =
+ StructType(Array(StructField("key", StringType), StructField("value", StringType)))
+
+ val q = spark.readStream
+ .schema(testSchema)
+ .option("maxFilesPerTrigger", 1)
+ .parquet(path)
+ .as[InputRowForConnectTest]
+ .groupByKey(x => x.key)
+ .transformWithState[OutputRowForConnectTest](
+ new BasicCountStatefulProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+ .writeStream
+ .format("memory")
+ .queryName("my_sink")
+ .start()
+
+ try {
+ q.processAllAvailable()
+ eventually(timeout(30.seconds)) {
+ checkDatasetUnorderly(
+ spark.table("my_sink").toDF().as[(String, String)],
+ ("a", "1"),
+ ("a", "2"),
+ ("b", "1"))
+ }
+ } finally {
+ q.stop()
+ }
+ }
+ }
+ }
+
+ test("transformWithState - streaming with initial state") {
+ withSQLConf(twsAdditionalSQLConf: _*) {
+ val session: SparkSession = spark
+ import session.implicits._
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ testData
+ .toDS()
+ .toDF("key", "value")
+ .repartition(3)
+ .write
+ .parquet(path)
+
+ val testSchema =
+ StructType(Array(StructField("key", StringType), StructField("value", StringType)))
+
+ val initDf = Seq(("init_1", "40.0", "a"), ("init_2", "100.0", "b"))
+ .toDS()
+ .groupByKey(x => x._3)
+ .mapValues(x => x)
+
+ val q = spark.readStream
+ .schema(testSchema)
+ .option("maxFilesPerTrigger", 1)
+ .parquet(path)
+ .as[(String, String)]
+ .groupByKey(x => x._1)
+ .transformWithState(
+ new TestInitialStatefulProcessor(),
+ TimeMode.None(),
+ OutputMode.Update(),
+ initialState = initDf)
+ .writeStream
+ .format("memory")
+ .queryName("my_sink")
+ .start()
+
+ try {
+ q.processAllAvailable()
+ eventually(timeout(30.seconds)) {
+ checkDatasetUnorderly(
+ spark.table("my_sink").toDF().as[(String, String)],
+ ("a", "2"),
+ ("a", "3"),
+ ("b", "2"))
+ }
+ } finally {
+ q.stop()
+ }
+ }
+ }
+ }
+
+ test("transformWithState - streaming with chaining of operators") {
+ withSQLConf(twsAdditionalSQLConf: _*) {
+ val session: SparkSession = spark
+ import session.implicits._
+
+ def timestamp(num: Int): Timestamp = {
+ new Timestamp(num * 1000)
+ }
+
+ val checkResultFunc: (Dataset[Row], Long) => Unit = { (batchDF, batchId) =>
+ val realDf = batchDF.collect().toSet
+ if (batchId == 0) {
+ assert(realDf.isEmpty, s"BatchId: $batchId, RealDF: $realDf")
+ } else if (batchId == 1) {
+ // eviction watermark = 15 (max event time from batch 0) - 5 = 10,
+ // late event watermark = 0 (eviction event time from batch 0)
+ val expectedDF = Seq(Row(timestamp(10), 1L)).toSet
+ assert(
+ realDf == expectedDF,
+ s"BatchId: $batchId, expectedDf: $expectedDF, RealDF: $realDf")
+ } else if (batchId == 2) {
+ // eviction watermark = 25 - 5 = 20, late event watermark = 10;
+ // the row with event time 5 < 10 is dropped, so it does not show up in the results;
+ // rows with eventTime <= 20 are finalized and emitted
+ val expectedDF = Seq(Row(timestamp(11), 1L), Row(timestamp(15), 1L)).toSet
+ assert(
+ realDf == expectedDF,
+ s"BatchId: $batchId, expectedDf: $expectedDF, RealDF: $realDf")
+ }
+ }
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val curTime = System.currentTimeMillis
+ val file1 = prepareInputData(path + "/text-test3.csv", Seq("a", "b"), Seq(10, 15))
+ file1.setLastModified(curTime + 2L)
+ val file2 = prepareInputData(path + "/text-test4.csv", Seq("a", "c"), Seq(11, 25))
+ file2.setLastModified(curTime + 4L)
+ val file3 = prepareInputData(path + "/text-test1.csv", Seq("a"), Seq(5))
+ file3.setLastModified(curTime + 6L)
+
+ val q = buildTestDf(path, spark)
+ .select(col("key").as("key"), timestamp_seconds(col("value")).as("eventTime"))
+ .withWatermark("eventTime", "5 seconds")
+ .as[(String, Timestamp)]
+ .groupByKey(x => x._1)
+ .transformWithState[OutputEventTimeRow](
+ new ChainingOfOpsStatefulProcessor(),
+ "outputTimestamp",
+ OutputMode.Append())
+ .groupBy("outputTimestamp")
+ .count()
+ .writeStream
+ .foreachBatch(checkResultFunc)
+ .outputMode("Append")
+ .start()
+
+ q.processAllAvailable()
+ eventually(timeout(30.seconds)) {
+ q.stop()
+ }
+ }
+ }
+ }
+
+ test("transformWithState - streaming with TTL and composite state variables") {
+ withSQLConf(twsAdditionalSQLConf: _*) {
+ val session: SparkSession = spark
+ import session.implicits._
+
+ val checkResultFunc = (batchDF: Dataset[(String, String)], batchId: Long) => {
+ if (batchId == 0) {
+ val expectedDF = Set(
+ ("count-0", "1"),
+ ("ttlCount-0", "1"),
+ ("ttlListState-0", "1"),
+ ("ttlMapState-0", "1"),
+ ("count-1", "1"),
+ ("ttlCount-1", "1"),
+ ("ttlListState-1", "1"),
+ ("ttlMapState-1", "1"))
+
+ val realDf = batchDF.collect().toSet
+ assert(realDf == expectedDF)
+
+ } else if (batchId == 1) {
+ val expectedDF = Set(
+ ("count-0", "2"),
+ ("ttlCount-0", "1"),
+ ("ttlListState-0", "1"),
+ ("ttlMapState-0", "1"),
+ ("count-1", "2"),
+ ("ttlCount-1", "1"),
+ ("ttlListState-1", "1"),
+ ("ttlMapState-1", "1"))
+
+ val realDf = batchDF.collect().toSet
+ assert(realDf == expectedDF)
+ }
+
+ if (batchId == 0) {
+ // let the TTL state expire
+ Thread.sleep(2000)
+ }
+ }
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val curTime = System.currentTimeMillis
+ val file1 = prepareInputData(path + "/text-test3.csv", Seq("1", "0"), Seq(0, 0))
+ file1.setLastModified(curTime + 2L)
+ val file2 = prepareInputData(path + "/text-test4.csv", Seq("1", "0"), Seq(0, 0))
+ file2.setLastModified(curTime + 4L)
+
+ val q = buildTestDf(path, spark)
+ .as[(String, String)]
+ .groupByKey(x => x._1)
+ .transformWithState(
+ new TTLTestStatefulProcessor(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+ .writeStream
+ .foreachBatch(checkResultFunc)
+ .outputMode("Update")
+ .start()
+ q.processAllAvailable()
+
+ eventually(timeout(30.seconds)) {
+ q.stop()
+ }
+ }
+ }
+ }
+
+ test("transformWithState - batch query") {
+ withSQLConf(twsAdditionalSQLConf: _*) {
+ val session: SparkSession = spark
+ import session.implicits._
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ testData
+ .toDS()
+ .toDF("key", "value")
+ .repartition(3)
+ .write
+ .parquet(path)
+
+ val testSchema =
+ StructType(Array(StructField("key", StringType), StructField("value", StringType)))
+
+ spark.read
+ .schema(testSchema)
+ .parquet(path)
+ .as[InputRowForConnectTest]
+ .groupByKey(x => x.key)
+ .transformWithState[OutputRowForConnectTest](
+ new BasicCountStatefulProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+ .write
+ .saveAsTable("my_sink")
+
+ checkDatasetUnorderly(
+ spark.table("my_sink").toDF().as[(String, String)],
+ ("a", "2"),
+ ("b", "1"))
+ }
+ }
+ }
+
+ /* Utility functions for tests */
+ def prepareInputData(inputPath: String, col1: Seq[String], col2: Seq[Int]): File = {
+ // Ensure the parent directory exists
+ val file = Paths.get(inputPath).toFile
+ val parentDir = file.getParentFile
+ if (parentDir != null && !parentDir.exists()) {
+ parentDir.mkdirs()
+ }
+
+ val writer = new BufferedWriter(new FileWriter(inputPath))
+ try {
+ col1.zip(col2).foreach { case (e1, e2) =>
+ writer.write(s"$e1, $e2\n")
+ }
+ } finally {
+ writer.close()
+ }
+ file
+ }
+
+ def buildTestDf(inputPath: String, sparkSession: SparkSession): DataFrame = {
+ sparkSession.readStream
+ .format("csv")
+ .schema(
+ new StructType()
+ .add(StructField("key", StringType))
+ .add(StructField("value", StringType)))
+ .option("maxFilesPerTrigger", 1)
+ .load(inputPath)
+ .select(col("key").as("key"), col("value").cast("integer"))
+ }
+}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/IntegrationTestUtils.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/IntegrationTestUtils.scala
index d38d9f3017a90..6e20db5d34938 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/IntegrationTestUtils.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/IntegrationTestUtils.scala
@@ -90,6 +90,11 @@ object IntegrationTestUtils {
Files.exists(Paths.get(filePath))
}
+ lazy val isAssemblyJarsDirExists: Boolean = {
+ val filePath = s"$sparkHome/assembly/target/$scalaDir/jars/"
+ Files.exists(Paths.get(filePath))
+ }
+
private[sql] def cleanUpHiveClassesDirIfNeeded(): Unit = {
def delete(f: File): Unit = {
if (f.exists()) {
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
index 4ec056da9f17d..059b8827b4b65 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
@@ -23,10 +23,12 @@ import java.util.concurrent.TimeUnit
import scala.concurrent.duration.FiniteDuration
-import org.scalatest.{BeforeAndAfterAll, Suite}
+import org.scalactic.source.Position
+import org.scalatest.{BeforeAndAfterAll, Suite, Tag}
import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.timeout
-import org.scalatest.time.SpanSugar._
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+import org.scalatest.time.SpanSugar._ // scalastyle:ignore
import org.apache.spark.SparkBuildInfo
import org.apache.spark.sql.connect.SparkSession
@@ -205,23 +207,43 @@ object SparkConnectServerUtils {
}
}
-trait RemoteSparkSession extends BeforeAndAfterAll { self: Suite =>
+trait RemoteSparkSession
+ extends AnyFunSuite // scalastyle:ignore funsuite
+ with BeforeAndAfterAll { self: Suite =>
import SparkConnectServerUtils._
var spark: SparkSession = _
protected lazy val serverPort: Int = port
override def beforeAll(): Unit = {
super.beforeAll()
- spark = createSparkSession()
+ if (IntegrationTestUtils.isAssemblyJarsDirExists) {
+ spark = createSparkSession()
+ }
}
override def afterAll(): Unit = {
+ def isArrowAllocatorIssue(message: String): Boolean = {
+ Option(message).exists(m =>
+ m.contains("closed with outstanding") ||
+ m.contains("Memory leaked"))
+ }
try {
if (spark != null) spark.stop()
} catch {
+ case e: IllegalStateException if isArrowAllocatorIssue(e.getMessage) =>
+ throw e
case e: Throwable => debug(e)
}
spark = null
super.afterAll()
}
+
+ override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
+ pos: Position): Unit = {
+ if (IntegrationTestUtils.isAssemblyJarsDirExists) {
+ super.test(testName, testTags: _*)(testFun)
+ } else {
+ super.ignore(testName, testTags: _*)(testFun)
+ }
+ }
}
diff --git a/sql/connect/common/pom.xml b/sql/connect/common/pom.xml
index e6745df9013ec..18fb06ff3341d 100644
--- a/sql/connect/common/pom.xml
+++ b/sql/connect/common/pom.xml
@@ -136,7 +136,7 @@
org.apache.maven.pluginsmaven-shade-plugin
-
+ false
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 4dcaa9a40142e..df907a84868fe 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -483,11 +483,15 @@ message SubqueryExpression {
// (Optional) Options specific to table arguments.
optional TableArgOptions table_arg_options = 3;
+ // (Optional) IN subquery values.
+ repeated Expression in_subquery_values = 4;
+
enum SubqueryType {
SUBQUERY_TYPE_UNKNOWN = 0;
SUBQUERY_TYPE_SCALAR = 1;
SUBQUERY_TYPE_EXISTS = 2;
SUBQUERY_TYPE_TABLE_ARG = 3;
+ SUBQUERY_TYPE_IN = 4;
}
// Nested message for table argument options.
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
index 6e469bb9027e1..22c3ca7e6e903 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
@@ -36,6 +36,8 @@ message MlCommand {
Write write = 4;
Read read = 5;
Evaluate evaluate = 6;
+ CleanCache clean_cache = 7;
+ GetCacheInfo get_cache_info = 8;
}
// Command for estimator.fit(dataset)
@@ -48,12 +50,18 @@ message MlCommand {
Relation dataset = 3;
}
- // Command to delete the cached object which could be a model
+ // Command to delete the cached objects, each of which could be a model
// or summary evaluated by a model
message Delete {
- ObjectRef obj_ref = 1;
+ repeated ObjectRef obj_refs = 1;
}
+ // Force clean-up of all the cached ML objects
+ message CleanCache { }
+
+ // Get information about all the cached ML objects
+ message GetCacheInfo { }
+
// Command to write ML operator
message Write {
// It could be an estimator/evaluator or the cached model
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
index 419ac3b7f74ae..ec169ba114a3d 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
@@ -626,16 +626,6 @@ class Dataset[T] private[sql] (
def transpose(): DataFrame =
buildTranspose(Seq.empty)
- /** @inheritdoc */
- def scalar(): Column = {
- Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.SCALAR))
- }
-
- /** @inheritdoc */
- def exists(): Column = {
- Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.EXISTS))
- }
-
/** @inheritdoc */
def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getLimitBuilder
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
index b15e8c28df744..090907a538c72 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
@@ -141,7 +141,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends sql.KeyValueGroupedDa
statefulProcessor: StatefulProcessor[K, V, U],
timeMode: TimeMode,
outputMode: OutputMode): Dataset[U] =
- unsupported()
+ transformWithStateHelper(statefulProcessor, timeMode, outputMode)
/** @inheritdoc */
private[sql] def transformWithState[U: Encoder, S: Encoder](
@@ -149,20 +149,40 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends sql.KeyValueGroupedDa
timeMode: TimeMode,
outputMode: OutputMode,
initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] =
- unsupported()
+ transformWithStateHelper(statefulProcessor, timeMode, outputMode, Some(initialState))
/** @inheritdoc */
override private[sql] def transformWithState[U: Encoder](
statefulProcessor: StatefulProcessor[K, V, U],
eventTimeColumnName: String,
- outputMode: OutputMode): Dataset[U] = unsupported()
+ outputMode: OutputMode): Dataset[U] =
+ transformWithStateHelper(
+ statefulProcessor,
+ TimeMode.EventTime(),
+ outputMode,
+ eventTimeColumnName = eventTimeColumnName)
/** @inheritdoc */
override private[sql] def transformWithState[U: Encoder, S: Encoder](
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
eventTimeColumnName: String,
outputMode: OutputMode,
- initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported()
+ initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] =
+ transformWithStateHelper(
+ statefulProcessor,
+ TimeMode.EventTime(),
+ outputMode,
+ Some(initialState),
+ eventTimeColumnName)
+
+ // This method is declared on the interface and should not be called directly; the real
+ // implementation is in the inheriting class.
+ protected[sql] def transformWithStateHelper[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessor[K, V, U],
+ timeMode: TimeMode,
+ outputMode: OutputMode,
+ initialState: Option[sql.KeyValueGroupedDataset[K, S]] = None,
+ eventTimeColumnName: String = ""): Dataset[U] = unsupported()
// Overrides...
/** @inheritdoc */
@@ -602,7 +622,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
}
val initialStateImpl = if (initialState.isDefined) {
- assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]])
initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]
} else {
null
@@ -632,6 +651,53 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
}
}
+ override protected[sql] def transformWithStateHelper[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessor[K, V, U],
+ timeMode: TimeMode,
+ outputMode: OutputMode,
+ initialState: Option[sql.KeyValueGroupedDataset[K, S]] = None,
+ eventTimeColumnName: String = ""): Dataset[U] = {
+ val outputEncoder = agnosticEncoderFor[U]
+ val stateEncoder = agnosticEncoderFor[S]
+ val inputEncoders: Seq[AgnosticEncoder[_]] = Seq(kEncoder, stateEncoder, ivEncoder)
+
+ // SparkUserDefinedFunction creates a udfPacket in which the input function is
+ // Java-serialized into bytes; we pass in `statefulProcessor` as the function so it can be
+ // serialized into bytes and deserialized back on the Connect server
+ val sparkUserDefinedFunc =
+ SparkUserDefinedFunction(statefulProcessor, inputEncoders, outputEncoder)
+ val funcProto = UdfToProtoUtils.toProto(sparkUserDefinedFunc)
+
+ val initialStateImpl = if (initialState.isDefined) {
+ initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]
+ } else {
+ null
+ }
+
+ sparkSession.newDataset[U](outputEncoder) { builder =>
+ val twsBuilder = builder.getGroupMapBuilder
+ val twsInfoBuilder = proto.TransformWithStateInfo.newBuilder()
+ if (!eventTimeColumnName.isEmpty) {
+ twsInfoBuilder.setEventTimeColumnName(eventTimeColumnName)
+ }
+ twsBuilder
+ .setInput(plan.getRoot)
+ .addAllGroupingExpressions(groupingExprs)
+ .setFunc(funcProto)
+ .setOutputMode(outputMode.toString)
+ .setTransformWithStateInfo(
+ twsInfoBuilder
+ // we pass the time mode as a string here; it is deterministically restored on the server
+ .setTimeMode(timeMode.toString)
+ .build())
+ if (initialStateImpl != null) {
+ twsBuilder
+ .addAllInitialGroupingExpressions(initialStateImpl.groupingExprs)
+ .setInitialInput(initialStateImpl.plan.getRoot)
+ }
+ }
+ }
+
private def getUdf[U: Encoder](nf: AnyRef, outputEncoder: AgnosticEncoder[U])(
inEncoders: AgnosticEncoder[_]*): proto.CommonInlineUserDefinedFunction = {
val inputEncoders = kEncoder +: inEncoders // Apply keyAs changes by setting kEncoder
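With transformWithStateHelper in place, the Connect Scala client exposes the same transformWithState overloads as Classic instead of throwing unsupported(). A hedged usage sketch, assuming a running Connect session `spark` plus the BasicCountStatefulProcessor and row case classes from the test suite above; the rate-source wiring is illustrative only:

import org.apache.spark.sql.streaming.{OutputMode, TimeMode}

val session = spark
import session.implicits._

// Count occurrences per key with a user-defined stateful processor.
val counts = spark.readStream
  .format("rate")
  .load()
  .selectExpr("cast(value % 3 as string) as key", "cast(value as string) as value")
  .as[InputRowForConnectTest]
  .groupByKey(_.key)
  .transformWithState[OutputRowForConnectTest](
    new BasicCountStatefulProcessor(),
    TimeMode.None(),
    OutputMode.Update())

val query = counts.writeStream
  .format("memory")
  .queryName("counts_sink")
  .outputMode("update")
  .start()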
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index 0af7c7b6d97a7..739b0318759e5 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -49,10 +49,11 @@ import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
+import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
-import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
+import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf, SubqueryExpression}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ExecutionListenerManager
@@ -420,8 +421,8 @@ class SparkSession private[sql] (
@DeveloperApi
def newDataset[T](encoder: AgnosticEncoder[T], cols: Seq[Column])(
f: proto.Relation.Builder => Unit): Dataset[T] = {
- val references = cols.flatMap(_.node.collect { case n: SubqueryExpressionNode =>
- n.relation
+ val references: Seq[proto.Relation] = cols.flatMap(_.node.collect {
+ case n: SubqueryExpression => n.ds.plan.getRoot
})
val builder = proto.Relation.newBuilder()
@@ -744,7 +745,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
lazy val isAPIModeConnect =
Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
.getOrElse("classic")
- .toLowerCase(Locale.ROOT) == "connect"
+ .toLowerCase(Locale.ROOT) == "connect" || System.getenv("SPARK_CONNECT_MODE") == "1"
val remoteString = sparkOptions
.get("spark.remote")
.orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
@@ -757,13 +758,14 @@ object SparkSession extends SparkSessionCompanion with Logging {
}
}
+ lazy val serverId = UUID.randomUUID().toString
+
server.synchronized {
if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectStartScript.exists(Files.exists(_))) {
val token = java.util.UUID.randomUUID().toString()
- val serverId = UUID.randomUUID().toString
server = Some {
val args =
Seq(
@@ -778,6 +780,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+ pb.environment().put("SPARK_CONNECT_MODE", "0")
pb.environment().put("SPARK_IDENT_STRING", serverId)
pb.environment().put("HOSTNAME", "local")
pb.environment().put("SPARK_CONNECT_AUTHENTICATE_TOKEN", token)
@@ -788,6 +791,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
Option(System.getenv("SPARK_LOG_DIR"))
.orElse(Option(System.getenv("SPARK_HOME")).map(p => Paths.get(p, "logs").toString))
.foreach { p =>
+ Files.createDirectories(Paths.get(p))
val logFile = Paths
.get(
p,
@@ -802,14 +806,18 @@ object SparkSession extends SparkSessionCompanion with Logging {
}
}
+ // Let the server fully start, to reduce the noise from connection retries.
+ Thread.sleep(1000L)
+
System.setProperty("spark.remote", s"sc://localhost/;token=$token")
// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = server.synchronized {
if (server.isDefined) {
- new ProcessBuilder(maybeConnectStopScript.get.toString)
- .start()
+ val builder = new ProcessBuilder(maybeConnectStopScript.get.toString)
+ builder.environment().put("SPARK_IDENT_STRING", serverId)
+ builder.start()
}
}
})
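The startup path above now reuses one serverId for both the launch and the shutdown hook, and hands it to the stop script via SPARK_IDENT_STRING so the hook stops exactly the server it started. A generic sketch of that ProcessBuilder-with-environment pattern, with a placeholder script path:

import java.util.UUID

val serverId = UUID.randomUUID().toString
val stopScript = "/path/to/sbin/stop-connect-server.sh" // placeholder

// Propagate the identifier through the child process environment so the
// stop script acts on the same server instance that was started earlier.
val builder = new ProcessBuilder(stopScript)
builder.environment().put("SPARK_IDENT_STRING", serverId)
builder.start()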
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index f29291594069d..d3dae47f4c471 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -192,7 +192,6 @@ private[client] object GrpcExceptionConverter {
new ParseException(
None,
Origin(),
- Origin(),
errorClass = params.errorClass.orNull,
messageParameters = params.messageParameters,
queryContext = params.queryContext)),
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 4925c7700b61f..e844237a3bb44 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -110,6 +110,15 @@ private[sql] class SparkConnectClient(
bstub.analyzePlan(request)
}
+ private def isValidUUID(uuid: String): Boolean = {
+ try {
+ UUID.fromString(uuid)
+ true
+ } catch {
+ case _: IllegalArgumentException => false
+ }
+ }
+
/**
* Execute the plan and return response iterator.
*
@@ -117,7 +126,9 @@ private[sql] class SparkConnectClient(
* done. If you don't close it, it and the underlying data will be cleaned up once the iterator
* is garbage collected.
*/
- def execute(plan: proto.Plan): CloseableIterator[proto.ExecutePlanResponse] = {
+ def execute(
+ plan: proto.Plan,
+ operationId: Option[String] = None): CloseableIterator[proto.ExecutePlanResponse] = {
artifactManager.uploadAllClassFileArtifacts()
val request = proto.ExecutePlanRequest
.newBuilder()
@@ -127,6 +138,13 @@ private[sql] class SparkConnectClient(
.setClientType(userAgent)
.addAllTags(tags.get.toSeq.asJava)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
+ operationId.foreach { opId =>
+ require(
+ isValidUUID(opId),
+ s"Invalid operationId: $opId. The id must be an UUID string of " +
+ "the format `00112233-4455-6677-8899-aabbccddeeff`")
+ request.setOperationId(opId)
+ }
if (configuration.useReattachableExecute) {
bstub.executePlanReattachable(request.build())
} else {
@@ -423,7 +441,6 @@ object SparkConnectClient {
def configuration: Configuration = _configuration
def userId(id: String): Builder = {
- // TODO this is not an optional field!
require(id != null && id.nonEmpty)
_configuration = _configuration.copy(userId = id)
this
@@ -706,12 +723,15 @@ object SparkConnectClient {
s"os/$osName").mkString(" ")
}
+ private lazy val sparkUser =
+ sys.env.getOrElse("SPARK_USER", System.getProperty("user.name", null))
+
/**
* Helper class that fully captures the configuration for a [[SparkConnectClient]].
*/
private[sql] case class Configuration(
- userId: String = null,
- userName: String = null,
+ userId: String = sparkUser,
+ userName: String = sparkUser,
host: String = "localhost",
port: Int = ConnectCommon.CONNECT_GRPC_BINDING_PORT,
token: Option[String] = None,
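execute() now accepts an optional caller-supplied operation id, which must be a UUID string. A hedged usage sketch, assuming an existing SparkConnectClient `client` and a proto.Plan `plan`:

import java.util.UUID

val operationId = UUID.randomUUID().toString
val responses = client.execute(plan, operationId = Some(operationId))
try {
  // Drain the response stream; a non-UUID id would have failed the require() check above.
  while (responses.hasNext) {
    responses.next()
  }
} finally {
  responses.close()
}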
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 4618c7e24d4ac..ceeece073da65 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -388,7 +388,7 @@ object ArrowDeserializers {
}
}
- case (TransformingEncoder(_, encoder, provider), v) =>
+ case (TransformingEncoder(_, encoder, provider, _), v) =>
new Deserializer[Any] {
private[this] val codec = provider()
private[this] val deserializer = deserializerFor(encoder, v, timeZoneId)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 584a318f039d8..d79fb25ec1a0b 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -52,6 +52,12 @@ class ArrowSerializer[T](
private[this] val allocator: BufferAllocator,
private[this] val timeZoneId: String,
private[this] val largeVarTypes: Boolean) {
+
+ // SPARK-51079: keep the old constructor for backward-compatibility.
+ def this(enc: AgnosticEncoder[T], allocator: BufferAllocator, timeZoneId: String) = {
+ this(enc, allocator, timeZoneId, false)
+ }
+
private val (root, serializer) =
ArrowSerializer.serializerFor(enc, allocator, timeZoneId, largeVarTypes)
private val vectors = root.getFieldVectors.asScala
@@ -485,7 +491,7 @@ object ArrowSerializer {
o => getter.invoke(o)
}
- case (TransformingEncoder(_, encoder, provider), v) =>
+ case (TransformingEncoder(_, encoder, provider, _), v) =>
new Serializer {
private[this] val codec = provider().asInstanceOf[Codec[Any, Any]]
private[this] val delegate: Serializer = serializerFor(encoder, v)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
index 54f45a434826a..1e798387726bb 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
@@ -27,10 +27,11 @@ import org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBounda
import org.apache.spark.sql.{functions, Column, Encoder}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder
import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
-import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SqlExpression, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame}
+import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SqlExpression, SubqueryExpression, SubqueryType, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame}
/**
* Converter for [[ColumnNode]] to [[proto.Expression]] conversions.
@@ -218,11 +219,15 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
case LazyExpression(child, _) =>
return apply(child, e)
- case SubqueryExpressionNode(relation, subqueryType, _) =>
+ case SubqueryExpression(ds, subqueryType, _) =>
+ val relation = ds.plan.getRoot
val b = builder.getSubqueryExpressionBuilder
b.setSubqueryType(subqueryType match {
case SubqueryType.SCALAR => proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_SCALAR
case SubqueryType.EXISTS => proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_EXISTS
+ case SubqueryType.IN(values) =>
+ b.addAllInSubqueryValues(values.map(value => apply(value, e)).asJava)
+ proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_IN
})
assert(relation.hasCommon && relation.getCommon.hasPlanId)
b.setPlanId(relation.getCommon.getPlanId)
@@ -311,22 +316,3 @@ case class ProtoColumnNode(
override def sql: String = expr.toString
override def children: Seq[ColumnNodeLike] = Seq.empty
}
-
-sealed trait SubqueryType
-
-object SubqueryType {
- case object SCALAR extends SubqueryType
- case object EXISTS extends SubqueryType
-}
-
-case class SubqueryExpressionNode(
- relation: proto.Relation,
- subqueryType: SubqueryType,
- override val origin: Origin = CurrentOrigin.get)
- extends ColumnNode {
- override def sql: String = subqueryType match {
- case SubqueryType.SCALAR => s"($relation)"
- case _ => s"$subqueryType ($relation)"
- }
- override def children: Seq[ColumnNodeLike] = Seq.empty
-}
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_from_json.json b/sql/connect/common/src/test/resources/query-tests/queries/function_from_json.json
index 3c359d024c24b..43b1abd1d59ad 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/function_from_json.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/function_from_json.json
@@ -37,7 +37,7 @@
}
}, {
"literal": {
- "string": "{\"type\":\"struct\",\"fields\":[{\"name\":\"id\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"a\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"b\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}}]}"
+ "string": "STRUCT\u003cid: BIGINT, a: INT, b: DOUBLE\u003e"
},
"common": {
"origin": {
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_from_json.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_from_json.proto.bin
index 001e2bd467409..59ecbec228612 100644
Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_from_json.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_from_json.proto.bin differ
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_from_xml.json b/sql/connect/common/src/test/resources/query-tests/queries/function_from_xml.json
index 04ffe209f170b..44faa65e6a0ce 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/function_from_xml.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/function_from_xml.json
@@ -37,7 +37,7 @@
}
}, {
"literal": {
- "string": "{\"type\":\"struct\",\"fields\":[{\"name\":\"id\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"a\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"b\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}}]}"
+ "string": "STRUCT\u003cid: BIGINT, a: INT, b: DOUBLE\u003e"
},
"common": {
"origin": {
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_from_xml.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_from_xml.proto.bin
index d4f149dc9f4f2..c40541dc98ee4 100644
Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_from_xml.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_from_xml.proto.bin differ
diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml
index bb30b01f778de..8da2b860fadfe 100644
--- a/sql/connect/server/pom.xml
+++ b/sql/connect/server/pom.xml
@@ -272,41 +272,11 @@
target/scala-${scala.binary.version}/classestarget/scala-${scala.binary.version}/test-classes
-
- org.codehaus.mojo
- build-helper-maven-plugin
-
-
- add-sources
- generate-sources
-
- add-source
-
-
-
- src/main/scala-${scala.binary.version}
-
-
-
-
- add-scala-test-sources
- generate-test-sources
-
- add-test-source
-
-
-
- src/test/gen-java
-
-
-
-
- org.apache.maven.pluginsmaven-shade-plugin
-
+ false
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 9f884b683079c..1b9f770e9e96a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -29,12 +29,14 @@ object Connect {
val CONNECT_GRPC_BINDING_ADDRESS =
buildStaticConf("spark.connect.grpc.binding.address")
+ .doc("The address for Spark Connect server to bind.")
.version("4.0.0")
.stringConf
.createOptional
val CONNECT_GRPC_BINDING_PORT =
buildStaticConf("spark.connect.grpc.binding.port")
+ .doc("The port for Spark Connect server to bind.")
.version("3.4.0")
.intConf
.createWithDefault(ConnectCommon.CONNECT_GRPC_BINDING_PORT)
@@ -331,4 +333,24 @@ object Connect {
Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
}
}
+
+ val CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE =
+ buildConf("spark.connect.session.connectML.mlCache.maxSize")
+ .doc("Maximum size of the MLCache per session. The cache will evict the least recently" +
+ "used models if the size exceeds this limit. The size is in bytes.")
+ .version("4.1.0")
+ .internal()
+ .bytesConf(ByteUnit.BYTE)
+ // By default, 1/3 of total designated memory (the configured -Xmx).
+ .createWithDefault(Runtime.getRuntime.maxMemory() / 3)
+
+ val CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT =
+ buildConf("spark.connect.session.connectML.mlCache.timeout")
+ .doc(
+ "Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
+ "used for this amount of time. The timeout is in minutes.")
+ .version("4.1.0")
+ .internal()
+ .timeConf(TimeUnit.MINUTES)
+ .createWithDefault(15)
}
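Both new entries are internal session confs, so they can be tuned like any other spark.* setting when the server-side session is created. A hedged sketch with illustrative values (the defaults are one third of -Xmx and 15 minutes):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .config("spark.connect.session.connectML.mlCache.maxSize", (512L * 1024 * 1024).toString) // 512 MiB
  .config("spark.connect.session.connectML.mlCache.timeout", "30") // minutes
  .getOrCreate()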
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index e8d8585020722..05fa976b5beab 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -18,27 +18,59 @@ package org.apache.spark.sql.connect.ml
import java.util.UUID
import java.util.concurrent.{ConcurrentMap, TimeUnit}
+import java.util.concurrent.atomic.AtomicLong
-import com.google.common.cache.CacheBuilder
+import scala.collection.mutable
+
+import com.google.common.cache.{CacheBuilder, RemovalNotification}
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.Model
import org.apache.spark.ml.util.ConnectHelper
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.service.SessionHolder
/**
* MLCache is for caching ML objects, typically for models and summaries evaluated by a model.
*/
-private[connect] class MLCache extends Logging {
+private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
private val helper = new ConnectHelper()
private val helperID = "______ML_CONNECT_HELPER______"
- private val cachedModel: ConcurrentMap[String, Object] = CacheBuilder
+ private def getMaxCacheSizeKB: Long = {
+ sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE) / 1024
+ }
+
+ private def getTimeoutMinute: Long = {
+ sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT)
+ }
+
+ private[ml] case class CacheItem(obj: Object, sizeBytes: Long)
+ private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = CacheBuilder
.newBuilder()
.softValues()
- .maximumSize(MLCache.MAX_CACHED_ITEMS)
- .expireAfterAccess(MLCache.CACHE_TIMEOUT_MINUTE, TimeUnit.MINUTES)
- .build[String, Object]()
+ .maximumWeight(getMaxCacheSizeKB)
+ .expireAfterAccess(getTimeoutMinute, TimeUnit.MINUTES)
+ .weigher((key: String, value: CacheItem) => {
+ Math.ceil(value.sizeBytes.toDouble / 1024).toInt
+ })
+ .removalListener((removed: RemovalNotification[String, CacheItem]) =>
+ totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
+ .build[String, CacheItem]()
.asMap()
+ private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)
+
+ private def estimateObjectSize(obj: Object): Long = {
+ obj match {
+ case model: Model[_] =>
+ model.asInstanceOf[Model[_]].estimatedSize
+ case _ =>
+ // There can only be Models in the cache, so we should never reach here.
+ 1
+ }
+ }
+
/**
* Cache an object into a map of MLCache, and return its key
* @param obj
@@ -48,7 +80,9 @@ private[connect] class MLCache extends Logging {
*/
def register(obj: Object): String = {
val objectId = UUID.randomUUID().toString
- cachedModel.put(objectId, obj)
+ val sizeBytes = estimateObjectSize(obj)
+ totalSizeBytes.addAndGet(sizeBytes)
+ cachedModel.put(objectId, CacheItem(obj, sizeBytes))
objectId
}
@@ -63,7 +97,7 @@ private[connect] class MLCache extends Logging {
if (refId == helperID) {
helper
} else {
- cachedModel.get(refId)
+ Option(cachedModel.get(refId)).map(_.obj).orNull
}
}
@@ -72,22 +106,26 @@ private[connect] class MLCache extends Logging {
* @param refId
* the key used to look up the corresponding object
*/
- def remove(refId: String): Unit = {
- cachedModel.remove(refId)
+ def remove(refId: String): Boolean = {
+ val removed = cachedModel.remove(refId)
+ // remove returns null if the key is not present
+ removed != null
}
/**
* Clear all the caches
*/
- def clear(): Unit = {
+ def clear(): Int = {
+ val size = cachedModel.size()
cachedModel.clear()
+ size
}
-}
-private[connect] object MLCache {
- // The maximum number of distinct items in the cache.
- private val MAX_CACHED_ITEMS = 100
-
- // The maximum time for an item to stay in the cache.
- private val CACHE_TIMEOUT_MINUTE = 60
+ def getInfo(): Array[String] = {
+ val info = mutable.ArrayBuilder.make[String]
+ cachedModel.forEach { case (key, value) =>
+ info += s"id: $key, obj: ${value.obj.getClass}, size: ${value.sizeBytes}"
+ }
+ info.result()
+ }
}
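MLCache now bounds the total estimated size of the cached models (weights in KB) rather than the entry count, and a removal listener keeps totalSizeBytes in sync as entries are evicted. A standalone, hedged sketch of the same Guava weigher/removal-listener combination with illustrative numbers:

import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong

import com.google.common.cache.{CacheBuilder, RemovalNotification}

case class Entry(payload: Array[Byte], sizeBytes: Long)

val totalSizeBytes = new AtomicLong(0)
val cache = CacheBuilder
  .newBuilder()
  .maximumWeight(64L * 1024) // total budget of 64 MB, expressed in KB
  .expireAfterAccess(15, TimeUnit.MINUTES)
  .weigher((key: String, value: Entry) => Math.ceil(value.sizeBytes.toDouble / 1024).toInt)
  .removalListener((removed: RemovalNotification[String, Entry]) =>
    totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
  .build[String, Entry]()

// Mirrors MLCache.register: account for the size before inserting.
def put(key: String, value: Entry): Unit = {
  totalSizeBytes.addAndGet(value.sizeBytes)
  cache.put(key, value)
}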
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 9a9e156f91cd4..5283639e4aa2b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.connect.ml
+import scala.collection.mutable
import scala.jdk.CollectionConverters.CollectionHasAsScala
import org.apache.spark.connect.proto
@@ -155,9 +156,7 @@ private[connect] object MLHandler extends Logging {
.setObjRef(proto.ObjectRef.newBuilder().setId(id)))
.build()
case a: Array[_] if a.nonEmpty && a.forall(_.isInstanceOf[Model[_]]) =>
- val ids = a.map { m =>
- mlCache.register(m.asInstanceOf[Model[_]])
- }
+ val ids = a.map(m => mlCache.register(m.asInstanceOf[Model[_]]))
proto.MlCommandResult
.newBuilder()
.setOperatorInfo(
@@ -171,15 +170,33 @@ private[connect] object MLHandler extends Logging {
}
case proto.MlCommand.CommandCase.DELETE =>
- val objId = mlCommand.getDelete.getObjRef.getId
- var result = false
- if (!objId.contains(".")) {
- mlCache.remove(objId)
- result = true
+ val ids = mutable.ArrayBuilder.make[String]
+ mlCommand.getDelete.getObjRefsList.asScala.toArray.foreach { objId =>
+ if (!objId.getId.contains(".")) {
+ if (mlCache.remove(objId.getId)) {
+ ids += objId.getId
+ }
+ }
}
proto.MlCommandResult
.newBuilder()
- .setParam(LiteralValueProtoConverter.toLiteralProto(result))
+ .setOperatorInfo(
+ proto.MlCommandResult.MlOperatorInfo
+ .newBuilder()
+ .setObjRef(proto.ObjectRef.newBuilder().setId(ids.result().mkString(","))))
+ .build()
+
+ case proto.MlCommand.CommandCase.CLEAN_CACHE =>
+ val size = mlCache.clear()
+ proto.MlCommandResult
+ .newBuilder()
+ .setParam(LiteralValueProtoConverter.toLiteralProto(size))
+ .build()
+
+ case proto.MlCommand.CommandCase.GET_CACHE_INFO =>
+ proto.MlCommandResult
+ .newBuilder()
+ .setParam(LiteralValueProtoConverter.toLiteralProto(mlCache.getInfo()))
.build()
case proto.MlCommand.CommandCase.WRITE =>
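Protobuf codegen gives the handler (and any client) builders for the two new commands; a hedged sketch of constructing them, with the transport out of scope:

import org.apache.spark.connect.proto

// Force-clear every cached ML object in the session.
val cleanCache = proto.MlCommand
  .newBuilder()
  .setCleanCache(proto.MlCommand.CleanCache.newBuilder())
  .build()

// Ask the server to describe what is currently cached.
val getCacheInfo = proto.MlCommand
  .newBuilder()
  .setGetCacheInfo(proto.MlCommand.GetCacheInfo.newBuilder())
  .build()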
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index c11a153cde5b8..fb9469cd480eb 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -416,10 +416,15 @@ private[ml] object MLUtils {
if (operators.isEmpty || !operators.contains(name)) {
throw MlUnsupportedException(s"Unsupported read for $name")
}
- operators(name)
- .getMethod("load", classOf[String])
- .invoke(null, path)
- .asInstanceOf[T]
+ try {
+ operators(name)
+ .getMethod("load", classOf[String])
+ .invoke(null, path)
+ .asInstanceOf[T]
+ } catch {
+ case e: InvocationTargetException if e.getCause != null =>
+ throw e.getCause
+ }
}
/**
@@ -621,8 +626,8 @@ private[ml] object MLUtils {
"isDistributed",
"logLikelihood",
"logPerplexity",
- "describeTopics")),
- (classOf[LocalLDAModel], Set("vocabSize")),
+ "describeTopics",
+ "vocabSize")),
(
classOf[DistributedLDAModel],
Set("trainingLogLikelihood", "logPrior", "getCheckpointFiles", "toLocal")),
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 24fc1275d4823..849dd9532405c 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -46,7 +46,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Encoders, ForeachWriter, Observation, Row}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTranspose}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedOrdinal, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTranspose}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions._
@@ -54,7 +54,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateStarAction}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TimeModes, TransformWithState, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateEventTimeWatermarkColumn, UpdateStarAction}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -81,7 +81,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
-import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.{ArrowUtils, CaseInsensitiveStringMap}
import org.apache.spark.storage.CacheId
@@ -661,7 +661,11 @@ class SparkConnectPlanner(
case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF |
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF =>
- transformTransformWithStateInPandas(pythonUdf, group, rel)
+ transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = true)
+
+ case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF |
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
+ transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = false)
case _ =>
throw InvalidPlanInput(
@@ -684,7 +688,71 @@ class SparkConnectPlanner(
rel.getGroupingExpressionsList,
rel.getSortingExpressionsList)
- if (rel.hasIsMapGroupsWithState) {
+ if (rel.hasTransformWithStateInfo) {
+ val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput
+
+ val twsInfo = rel.getTransformWithStateInfo
+ val keyDeserializer = udf.inputDeserializer(ds.groupingAttributes)
+ val outputAttr = udf.outputObjAttr
+
+ val timeMode = TimeModes(twsInfo.getTimeMode)
+ val outputMode = InternalOutputModes(rel.getOutputMode)
+
+ val twsNode = if (hasInitialState) {
+ val statefulProcessor = unpackedUdf.function
+ .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]]
+ val initDs = UntypedKeyValueGroupedDataset(
+ rel.getInitialInput,
+ rel.getInitialGroupingExpressionsList,
+ rel.getSortingExpressionsList)
+ new TransformWithState(
+ keyDeserializer,
+ ds.valueDeserializer,
+ ds.groupingAttributes,
+ ds.dataAttributes,
+ statefulProcessor,
+ timeMode,
+ outputMode,
+ udf.inEnc.asInstanceOf[ExpressionEncoder[Any]],
+ outputAttr,
+ ds.analyzed,
+ hasInitialState,
+ initDs.groupingAttributes,
+ initDs.dataAttributes,
+ initDs.valueDeserializer,
+ initDs.analyzed)
+ } else {
+ val statefulProcessor =
+ unpackedUdf.function.asInstanceOf[StatefulProcessor[Any, Any, Any]]
+ new TransformWithState(
+ keyDeserializer,
+ ds.valueDeserializer,
+ ds.groupingAttributes,
+ ds.dataAttributes,
+ statefulProcessor,
+ timeMode,
+ outputMode,
+ udf.inEnc.asInstanceOf[ExpressionEncoder[Any]],
+ outputAttr,
+ ds.analyzed,
+ hasInitialState,
+ ds.groupingAttributes,
+ ds.dataAttributes,
+ keyDeserializer,
+ LocalRelation(ds.vEncoder.schema))
+ }
+ val serializedPlan = SerializeFromObject(udf.outputNamedExpression, twsNode)
+
+ if (twsInfo.hasEventTimeColumnName) {
+ val eventTimeWrappedPlan = UpdateEventTimeWatermarkColumn(
+ UnresolvedAttribute(twsInfo.getEventTimeColumnName),
+ None,
+ serializedPlan)
+ eventTimeWrappedPlan
+ } else {
+ serializedPlan
+ }
+ } else if (rel.hasIsMapGroupsWithState) {
val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput
val initialDs = if (hasInitialState) {
UntypedKeyValueGroupedDataset(
@@ -1038,10 +1106,11 @@ class SparkConnectPlanner(
.logicalPlan
}
- private def transformTransformWithStateInPandas(
+ private def transformTransformWithStateInPySpark(
pythonUdf: PythonUDF,
groupedDs: RelationalGroupedDataset,
- rel: proto.GroupMap): LogicalPlan = {
+ rel: proto.GroupMap,
+ usePandas: Boolean): LogicalPlan = {
val twsInfo = rel.getTransformWithStateInfo
val outputSchema: StructType = {
transformDataType(twsInfo.getOutputSchema) match {
@@ -1067,25 +1136,52 @@ class SparkConnectPlanner(
.builder(groupedDs.df.logicalPlan.output)
.asInstanceOf[PythonUDF]
- groupedDs
- .transformWithStateInPandas(
- Column(resolvedPythonUDF),
- outputSchema,
- rel.getOutputMode,
- twsInfo.getTimeMode,
- initialStateDs,
- twsInfo.getEventTimeColumnName)
- .logicalPlan
+ if (usePandas) {
+ groupedDs
+ .transformWithStateInPandas(
+ Column(resolvedPythonUDF),
+ outputSchema,
+ rel.getOutputMode,
+ twsInfo.getTimeMode,
+ initialStateDs,
+ twsInfo.getEventTimeColumnName)
+ .logicalPlan
+ } else {
+          // Row-based variant: the user function receives Rows instead of pandas DataFrames
+ groupedDs
+ .transformWithStateInPySpark(
+ Column(resolvedPythonUDF),
+ outputSchema,
+ rel.getOutputMode,
+ twsInfo.getTimeMode,
+ initialStateDs,
+ twsInfo.getEventTimeColumnName)
+ .logicalPlan
+ }
+
} else {
- groupedDs
- .transformWithStateInPandas(
- Column(pythonUdf),
- outputSchema,
- rel.getOutputMode,
- twsInfo.getTimeMode,
- null,
- twsInfo.getEventTimeColumnName)
- .logicalPlan
+ if (usePandas) {
+ groupedDs
+ .transformWithStateInPandas(
+ Column(pythonUdf),
+ outputSchema,
+ rel.getOutputMode,
+ twsInfo.getTimeMode,
+ null,
+ twsInfo.getEventTimeColumnName)
+ .logicalPlan
+ } else {
+          // Row-based variant: the user function receives Rows instead of pandas DataFrames
+ groupedDs
+ .transformWithStateInPySpark(
+ Column(pythonUdf),
+ outputSchema,
+ rel.getOutputMode,
+ twsInfo.getTimeMode,
+ null,
+ twsInfo.getEventTimeColumnName)
+ .logicalPlan
+ }
}
}
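
For readers following the TransformWithState handling added above (both the typed Scala path and the new row-based Python path): a minimal sketch of the typed API the planner now maps onto a TransformWithState node. The processor and handle signatures below reflect the 4.0-era transformWithState surface and should be treated as assumptions rather than the exact interface; the processor simply counts input rows per key and is not part of this patch.

    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig, ValueState}

    // Hypothetical example processor: counts rows seen per key.
    class CountProcessor extends StatefulProcessor[String, String, (String, Long)] {
      @transient private var countState: ValueState[Long] = _

      override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
        // State variables are registered once per partition through the handle.
        countState = getHandle.getValueState[Long]("count", Encoders.scalaLong, TTLConfig.NONE)
      }

      override def handleInputRows(
          key: String,
          inputRows: Iterator[String],
          timerValues: TimerValues): Iterator[(String, Long)] = {
        val updated = (if (countState.exists()) countState.get() else 0L) + inputRows.size
        countState.update(updated)
        Iterator.single((key, updated))
      }
    }

    // Assumed usage shape over a Connect session:
    //   ds.groupByKey(identity)
    //     .transformWithState(new CountProcessor, TimeMode.None(), OutputMode.Update())
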
@@ -2309,7 +2405,7 @@ class SparkConnectPlanner(
private def transformSortOrder(order: proto.Expression.SortOrder) = {
expressions.SortOrder(
- child = transformExpression(order.getChild),
+ child = transformSortOrderAndReplaceOrdinals(order.getChild),
direction = order.getDirection match {
case proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING =>
expressions.Ascending
@@ -2323,6 +2419,19 @@ class SparkConnectPlanner(
sameOrderExpressions = Seq.empty)
}
+ /**
+ * Transforms an input protobuf sort order expression into the Catalyst expression and converts
+ * top-level integer [[Literal]]s to [[UnresolvedOrdinal]]s, if `orderByOrdinal` is enabled.
+ */
+ private def transformSortOrderAndReplaceOrdinals(sortItem: proto.Expression) = {
+ val transformedSortItem = transformExpression(sortItem)
+ if (session.sessionState.conf.orderByOrdinal) {
+ replaceIntegerLiteralWithOrdinal(transformedSortItem)
+ } else {
+ transformedSortItem
+ }
+ }
+
private def transformDrop(rel: proto.Drop): LogicalPlan = {
var output = Dataset.ofRows(session, transformRelation(rel.getInput))
if (rel.getColumnsCount > 0) {
@@ -2375,27 +2484,28 @@ class SparkConnectPlanner(
input
}
- val groupingExprs = rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
+ val groupingExpressionsWithOrdinals = rel.getGroupingExpressionsList.asScala.toSeq
+ .map(transformGroupingExpressionAndReplaceOrdinals)
val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq
.map(expr => transformExpressionWithTypedReduceExpression(expr, logicalPlan))
- val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression)
+ val aliasedAgg = (groupingExpressionsWithOrdinals ++ aggExprs).map(toNamedExpression)
rel.getGroupType match {
case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
logical.Aggregate(
- groupingExpressions = groupingExprs,
+ groupingExpressions = groupingExpressionsWithOrdinals,
aggregateExpressions = aliasedAgg,
child = logicalPlan)
case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
logical.Aggregate(
- groupingExpressions = Seq(Rollup(groupingExprs.map(Seq(_)))),
+ groupingExpressions = Seq(Rollup(groupingExpressionsWithOrdinals.map(Seq(_)))),
aggregateExpressions = aliasedAgg,
child = logicalPlan)
case proto.Aggregate.GroupType.GROUP_TYPE_CUBE =>
logical.Aggregate(
- groupingExpressions = Seq(Cube(groupingExprs.map(Seq(_)))),
+ groupingExpressions = Seq(Cube(groupingExpressionsWithOrdinals.map(Seq(_)))),
aggregateExpressions = aliasedAgg,
child = logicalPlan)
@@ -2413,21 +2523,23 @@ class SparkConnectPlanner(
.map(expressions.Literal.apply)
}
logical.Pivot(
- groupByExprsOpt = Some(groupingExprs.map(toNamedExpression)),
+ groupByExprsOpt = Some(groupingExpressionsWithOrdinals.map(toNamedExpression)),
pivotColumn = pivotExpr,
pivotValues = valueExprs,
aggregates = aggExprs,
child = logicalPlan)
case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
- val groupingSetsExprs = rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets =>
- getGroupingSets.getGroupingSetList.asScala.toSeq.map(transformExpression)
- }
+ val groupingSetsExpressionsWithOrdinals =
+ rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets =>
+ getGroupingSets.getGroupingSetList.asScala.toSeq
+ .map(transformGroupingExpressionAndReplaceOrdinals)
+ }
logical.Aggregate(
groupingExpressions = Seq(
GroupingSets(
- groupingSets = groupingSetsExprs,
- userGivenGroupByExprs = groupingExprs)),
+ groupingSets = groupingSetsExpressionsWithOrdinals,
+ userGivenGroupByExprs = groupingExpressionsWithOrdinals)),
aggregateExpressions = aliasedAgg,
child = logicalPlan)
@@ -2435,6 +2547,20 @@ class SparkConnectPlanner(
}
}
+ /**
+ * Transforms an input protobuf grouping expression into the Catalyst expression and converts
+ * top-level integer [[Literal]]s to [[UnresolvedOrdinal]]s, if `groupByOrdinal` is enabled.
+ */
+ private def transformGroupingExpressionAndReplaceOrdinals(
+ groupingExpression: proto.Expression) = {
+ val transformedGroupingExpression = transformExpression(groupingExpression)
+ if (session.sessionState.conf.groupByOrdinal) {
+ replaceIntegerLiteralWithOrdinal(transformedGroupingExpression)
+ } else {
+ transformedGroupingExpression
+ }
+ }
+
@deprecated("TypedReduce is now implemented using a normal UDAF aggregator.", "4.0.0")
private def transformTypedReduceExpression(
fun: proto.Expression.UnresolvedFunction,
@@ -3961,6 +4087,12 @@ class SparkConnectPlanner(
} else {
UnresolvedTableArgPlanId(planId)
}
+ case proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_IN =>
+ UnresolvedInSubqueryPlanId(
+ getSubqueryExpression.getInSubqueryValuesList.asScala.map { value =>
+ transformExpression(value)
+ }.toSeq,
+ planId)
case other => throw InvalidPlanInput(s"Unknown SubqueryType $other")
}
}
@@ -4002,6 +4134,16 @@ class SparkConnectPlanner(
}
}
+ /**
+ * Replaces a top-level integer [[Literal]] in a grouping expression with [[UnresolvedOrdinal]]
+ * that has the same index.
+ */
+ private def replaceIntegerLiteralWithOrdinal(groupingExpression: Expression) =
+ groupingExpression match {
+ case Literal(value: Int, IntegerType) => UnresolvedOrdinal(value)
+ case other => other
+ }
+
private def assertPlan(assertion: Boolean, message: => String = ""): Unit = {
if (!assertion) throw InvalidPlanInput(message)
}
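
The ordinal handling added above is intentionally narrow: only a bare integer literal at the top level of a groupBy/orderBy expression is rewritten, and only when the corresponding SQL conf (`groupByOrdinal` / `orderByOrdinal`) is enabled. A standalone restatement of the rule using the same Catalyst classes, for illustration only:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedOrdinal
    import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
    import org.apache.spark.sql.types.IntegerType

    // Top-level integer literals become position references; anything else is left alone,
    // so expressions such as lit(1) + lit(2) or string literals keep their literal meaning.
    def replaceIntegerLiteralWithOrdinal(expr: Expression): Expression = expr match {
      case Literal(value: Int, IntegerType) => UnresolvedOrdinal(value)
      case other => other
    }
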
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 631885a5d741c..6f252c0cd9480 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -112,7 +112,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
new ConcurrentHashMap()
// ML model cache
- private[connect] lazy val mlCache = new MLCache()
+ private[connect] lazy val mlCache = new MLCache(this)
// Mapping from id to StreamingQueryListener. Used for methods like removeListener() in
// StreamingQueryManager.
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index a156be189c650..80a580eb06990 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -74,10 +74,10 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
.build[ExecuteKey, ExecuteInfo]()
/** The time when the last execution was removed. */
- private var lastExecutionTimeNs: AtomicLong = new AtomicLong(System.nanoTime())
+ private val lastExecutionTimeNs: AtomicLong = new AtomicLong(System.nanoTime())
/** Executor for the periodic maintenance */
- private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+ private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
new AtomicReference[ScheduledExecutorService]()
/**
@@ -249,7 +249,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
// Visible for testing.
private[connect] def periodicMaintenance(timeoutNs: Long): Unit = {
// Find any detached executions that expired and should be removed.
- logInfo("Started periodic run of SparkConnectExecutionManager maintenance.")
+ logDebug("Started periodic run of SparkConnectExecutionManager maintenance.")
val nowNs = System.nanoTime()
executions.forEach((_, executeHolder) => {
@@ -266,7 +266,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
}
})
- logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.")
+ logDebug("Finished periodic run of SparkConnectExecutionManager maintenance.")
}
// For testing.
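
The var-to-val changes in this file and in the session manager below are safe because the only mutable piece is inside the AtomicReference; the field itself is never reassigned. A hedged, self-contained sketch of that idiom (names are illustrative, not the actual maintenance code):

    import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
    import java.util.concurrent.atomic.AtomicReference

    object MaintenanceScheduler {
      // The reference is mutable, the field is not: a val is enough.
      private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
        new AtomicReference[ScheduledExecutorService]()

      // Lazily starts the periodic task exactly once, even with concurrent callers.
      def ensureScheduled(task: Runnable, intervalMs: Long): Unit = {
        if (scheduledExecutor.get() == null) {
          val executor = Executors.newSingleThreadScheduledExecutor()
          if (scheduledExecutor.compareAndSet(null, executor)) {
            executor.scheduleAtFixedRate(task, intervalMs, intervalMs, TimeUnit.MILLISECONDS)
          } else {
            executor.shutdown() // another thread won the race; discard this executor
          }
        }
      }
    }
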
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
index 445f40d25edcd..8fbcf3218a003 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
@@ -78,7 +78,7 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
* final ResultComplete response.
*/
def cleanUp(): Unit = {
- var listener = streamingQueryServerSideListener.getAndSet(null)
+ val listener = streamingQueryServerSideListener.getAndSet(null)
if (listener != null) {
sessionHolder.session.streams.removeListener(listener)
listener.sendResultComplete()
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 8fa64ddcce49e..aab338cc06ff1 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -363,7 +363,7 @@ object SparkConnectService extends Logging {
* Starts the GRPC Service.
*/
private def startGRPCService(): Unit = {
- val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", true)
+ val debugMode = SparkEnv.get.conf.getBoolean("spark.connect.grpc.debug.enabled", false)
val bindAddress = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_ADDRESS)
val startPort = SparkEnv.get.conf.get(CONNECT_GRPC_BINDING_PORT)
val sparkConnectService = new SparkConnectService(debugMode)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index 8581bb7b98f05..572d760187e9d 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -49,7 +49,7 @@ class SparkConnectSessionManager extends Logging {
.build[SessionKey, SessionHolderInfo]()
/** Executor for the periodic maintenance */
- private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+ private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
new AtomicReference[ScheduledExecutorService]()
private def validateSessionId(
@@ -234,7 +234,7 @@ class SparkConnectSessionManager extends Logging {
defaultInactiveTimeoutMs: Long,
ignoreCustomTimeout: Boolean): Unit = {
// Find any sessions that expired and should be removed.
- logInfo("Started periodic run of SparkConnectSessionManager maintenance.")
+ logDebug("Started periodic run of SparkConnectSessionManager maintenance.")
def shouldExpire(info: SessionHolderInfo, nowMs: Long): Boolean = {
val timeoutMs = if (info.customInactiveTimeoutMs.isDefined && !ignoreCustomTimeout) {
@@ -262,7 +262,7 @@ class SparkConnectSessionManager extends Logging {
}
})
- logInfo("Finished periodic run of SparkConnectSessionManager maintenance.")
+ logDebug("Finished periodic run of SparkConnectSessionManager maintenance.")
}
private def newIsolatedSession(): SparkSession = {
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
index 3da2548b456e8..beff193f6701f 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
@@ -186,7 +186,7 @@ private[connect] class SparkConnectStreamingQueryCache(
private[service] val taggedQueries: ConcurrentMap[String, QueryCacheKeySet] =
new ConcurrentHashMap[String, QueryCacheKeySet]
- private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+ private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
new AtomicReference[ScheduledExecutorService]()
/** Schedules periodic checks if it is not already scheduled */
@@ -218,7 +218,7 @@ private[connect] class SparkConnectStreamingQueryCache(
(k, v) => {
if (v == null || !v.addKey(queryKey)) {
// Create a new QueryCacheKeySet if the entry is absent or being removed.
- var keys = mutable.HashSet.empty[QueryCacheKey]
+ val keys = mutable.HashSet.empty[QueryCacheKey]
keys.add(queryKey)
new QueryCacheKeySet(keys = keys)
} else {
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala
index c35eb58edfbe7..1f335c9ce0051 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala
@@ -43,7 +43,7 @@ private[ui] class SparkConnectServerSessionPage(parent: SparkConnectServerTab)
store
.getSession(sessionId)
.map { sessionStat =>
- generateBasicStats() ++
+ generateBasicStats(sessionId) ++
++
User
@@ -64,9 +64,12 @@ private[ui] class SparkConnectServerSessionPage(parent: SparkConnectServerTab)
}
/** Generate basic stats of the Spark Connect Server */
- private def generateBasicStats(): Seq[Node] = {
+ private def generateBasicStats(sessionId: String): Seq[Node] = {
val timeSinceStart = System.currentTimeMillis() - startTime.getTime
+
+ Session ID: {sessionId}
+
Started at: {formatDate(startTime)}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 76ce34a67e748..73bc1f2086aef 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.connect.SparkConnectTestUtils
+import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.SessionHolder
trait FakeArrayParams extends Params {
@@ -379,4 +380,32 @@ class MLSuite extends MLHelper {
.map(_.getString)
.toArray sameElements Array("a", "b", "c"))
}
+
+ test("Memory limitation of MLCache works") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val memorySizeBytes = 1024 * 16
+ sessionHolder.session.conf
+ .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE.key, memorySizeBytes)
+ trainLogisticRegressionModel(sessionHolder)
+ assert(sessionHolder.mlCache.cachedModel.size() == 1)
+ assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
+ val modelSizeBytes = sessionHolder.mlCache.totalSizeBytes.get()
+ val maxNumModels = memorySizeBytes / modelSizeBytes.toInt
+
+ // All models will be kept if the total size is less than the memory limit.
+ for (i <- 1 until maxNumModels) {
+ trainLogisticRegressionModel(sessionHolder)
+ assert(sessionHolder.mlCache.cachedModel.size() == i + 1)
+ assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
+ assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes)
+ }
+
+ // Old models will be removed if new ones are added and the total size exceeds the memory limit.
+ for (_ <- 0 until 3) {
+ trainLogisticRegressionModel(sessionHolder)
+ assert(sessionHolder.mlCache.cachedModel.size() == maxNumModels)
+ assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
+ assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes)
+ }
+ }
}
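
The new test pins down the eviction contract rather than the implementation: models keep getting cached until the per-session byte budget is reached, after which the oldest entries are dropped while newer ones stay. A minimal hand-rolled sketch of that contract (not the actual MLCache, which this patch wires through SessionHolder):

    import scala.collection.mutable

    // Illustrative only: insertion-ordered map plus a running byte counter.
    class SizeBoundedCache[K](maxBytes: Long) {
      private val entries = mutable.LinkedHashMap.empty[K, Long] // key -> entry size in bytes
      private var totalBytes = 0L

      def put(key: K, sizeBytes: Long): Unit = {
        entries.remove(key).foreach(totalBytes -= _) // avoid double-counting re-inserted keys
        entries.put(key, sizeBytes)
        totalBytes += sizeBytes
        // Evict the oldest entries until we are back under the budget, keeping the newest one.
        while (totalBytes > maxBytes && entries.size > 1) {
          val (oldestKey, oldestSize) = entries.head
          entries.remove(oldestKey)
          totalBytes -= oldestSize
        }
      }

      def size: Int = entries.size
      def currentBytes: Long = totalBytes
    }
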
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala
index 63d623cd2779b..82c8192fe0705 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectWithSessionExtensionSuite.scala
@@ -55,6 +55,9 @@ class SparkConnectWithSessionExtensionSuite extends SparkFunSuite {
override def parseQuery(sqlText: String): LogicalPlan =
delegate.parseQuery(sqlText)
+
+ override def parseRoutineParam(sqlText: String): StructType =
+ delegate.parseRoutineParam(sqlText)
}
test("Parse table name with test parser") {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
index e681aa4726f8f..a158ca9fad8ce 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
@@ -193,6 +193,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
try {
val name = "classes/smallClassFile.class"
val artifactPath = inputFilePath.resolve("smallClassFile.class")
+ assume(artifactPath.toFile.exists)
addSingleChunkArtifact(handler, name, artifactPath)
handler.onCompleted()
val response = ThreadUtils.awaitResult(promise.future, 5.seconds)
@@ -217,6 +218,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
try {
val name = "jars/junitLargeJar.jar"
val artifactPath = inputFilePath.resolve("junitLargeJar.jar")
+ assume(artifactPath.toFile.exists)
addChunkedArtifact(handler, name, artifactPath)
handler.onCompleted()
val response = ThreadUtils.awaitResult(promise.future, 5.seconds)
@@ -250,6 +252,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
inputFilePath.resolve("junitLargeJar.jar"),
inputFilePath.resolve("smallClassFileDup.class"),
inputFilePath.resolve("smallJar.jar"))
+ artifactPaths.foreach(p => assume(p.toFile.exists))
addSingleChunkArtifact(handler, names.head, artifactPaths.head)
addChunkedArtifact(handler, names(1), artifactPaths(1))
@@ -281,6 +284,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
try {
val name = "classes/smallClassFile.class"
val artifactPath = inputFilePath.resolve("smallClassFile.class")
+ assume(artifactPath.toFile.exists)
val dataChunks = getDataChunks(artifactPath)
assert(dataChunks.size == 1)
val bytes = dataChunks.head
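
The assume calls added in this suite rely on ScalaTest's cancellation mechanism: when the condition is false the test is reported as canceled rather than failed, which is the right outcome when a binary fixture is simply absent from the checkout. A small hedged sketch (the path below is a placeholder, not the suite's real resource location, which is resolved via its ResourceHelper):

    import java.nio.file.Paths
    import org.scalatest.funsuite.AnyFunSuite

    class FixtureGuardSuite extends AnyFunSuite {
      test("artifact handling with an on-disk fixture") {
        // Placeholder path for illustration only.
        val fixture = Paths.get("/tmp/fixtures/junitLargeJar.jar")
        // assume() cancels (does not fail) the test when the fixture is missing.
        assume(fixture.toFile.exists, s"missing fixture: $fixture")
        // ... exercise addChunkedArtifact / the handler here ...
      }
    }
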
diff --git a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt
index bd2311634a5bc..1e3fc590644ad 100644
--- a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt
+++ b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-jdk21-results.txt
@@ -2,143 +2,143 @@
put rows
================================================================================================
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
putting 10000 rows (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
---------------------------------------------------------------------------------------------------------------------------------------
-In-memory 8 9 1 1.2 822.1 1.0X
-RocksDB (trackTotalNumberOfRows: true) 45 47 2 0.2 4455.4 0.2X
-RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1655.9 0.5X
+In-memory 8 9 1 1.2 815.6 1.0X
+RocksDB (trackTotalNumberOfRows: true) 46 47 2 0.2 4559.1 0.2X
+RocksDB (trackTotalNumberOfRows: false) 17 18 1 0.6 1678.7 0.5X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
putting 10000 rows (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------
-In-memory 8 9 1 1.2 805.6 1.0X
-RocksDB (trackTotalNumberOfRows: true) 46 47 1 0.2 4561.6 0.2X
-RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1637.5 0.5X
+In-memory 8 9 1 1.3 798.1 1.0X
+RocksDB (trackTotalNumberOfRows: true) 47 48 2 0.2 4659.8 0.2X
+RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1663.4 0.5X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
putting 10000 rows (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------
-In-memory 8 8 1 1.3 782.0 1.0X
-RocksDB (trackTotalNumberOfRows: true) 45 47 1 0.2 4537.4 0.2X
-RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1633.0 0.5X
+In-memory 8 9 1 1.3 794.9 1.0X
+RocksDB (trackTotalNumberOfRows: true) 46 48 1 0.2 4625.7 0.2X
+RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1660.7 0.5X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
putting 10000 rows (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
---------------------------------------------------------------------------------------------------------------------------------
-In-memory 8 8 1 1.3 783.3 1.0X
-RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4484.9 0.2X
-RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1641.4 0.5X
+In-memory 8 8 1 1.3 788.6 1.0X
+RocksDB (trackTotalNumberOfRows: true) 46 47 1 0.2 4557.0 0.2X
+RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1650.3 0.5X
================================================================================================
merge rows
================================================================================================
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
merging 10000 rows with 10 values per key (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-RocksDB (trackTotalNumberOfRows: true) 565 579 7 0.0 56471.0 1.0X
-RocksDB (trackTotalNumberOfRows: false) 182 188 3 0.1 18161.0 3.1X
+RocksDB (trackTotalNumberOfRows: true) 574 585 6 0.0 57387.8 1.0X
+RocksDB (trackTotalNumberOfRows: false) 181 186 3 0.1 18065.2 3.2X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
merging 10000 rows with 10 values per key (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------------------------------------------
-RocksDB (trackTotalNumberOfRows: true) 500 512 5 0.0 50023.6 1.0X
-RocksDB (trackTotalNumberOfRows: false) 183 188 3 0.1 18312.9 2.7X
+RocksDB (trackTotalNumberOfRows: true) 504 515 5 0.0 50382.4 1.0X
+RocksDB (trackTotalNumberOfRows: false) 179 185 3 0.1 17882.2 2.8X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
merging 10000 rows with 10 values per key (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------------------------------------------
-RocksDB (trackTotalNumberOfRows: true) 436 447 5 0.0 43613.9 1.0X
-RocksDB (trackTotalNumberOfRows: false) 181 186 3 0.1 18065.5 2.4X
+RocksDB (trackTotalNumberOfRows: true) 442 455 6 0.0 44235.2 1.0X
+RocksDB (trackTotalNumberOfRows: false) 180 185 3 0.1 17971.5 2.5X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
merging 10000 rows with 10 values per key (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
--------------------------------------------------------------------------------------------------------------------------------------------------------
-RocksDB (trackTotalNumberOfRows: true) 413 425 5 0.0 41349.9 1.0X
-RocksDB (trackTotalNumberOfRows: false) 181 187 4 0.1 18075.6 2.3X
+RocksDB (trackTotalNumberOfRows: true) 424 436 5 0.0 42391.9 1.0X
+RocksDB (trackTotalNumberOfRows: false) 179 185 4 0.1 17923.5 2.4X
================================================================================================
delete rows
================================================================================================
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
trying to delete 10000 rows from 10000 rows(10000 rows are non-existing - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 0 0 0 26.8 37.3 1.0X
-RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4396.1 0.0X
-RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1522.7 0.0X
+In-memory 0 1 0 27.1 36.9 1.0X
+RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4470.0 0.0X
+RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1583.0 0.0X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
trying to delete 10000 rows from 10000 rows(5000 rows are non-existing - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 7 7 0 1.5 666.2 1.0X
-RocksDB (trackTotalNumberOfRows: true) 44 46 1 0.2 4392.1 0.2X
-RocksDB (trackTotalNumberOfRows: false) 15 16 0 0.7 1511.0 0.4X
+In-memory 7 7 0 1.5 651.4 1.0X
+RocksDB (trackTotalNumberOfRows: true) 46 47 1 0.2 4580.3 0.1X
+RocksDB (trackTotalNumberOfRows: false) 16 17 0 0.6 1582.7 0.4X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
trying to delete 10000 rows from 10000 rows(1000 rows are non-existing - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 7 8 1 1.4 714.3 1.0X
-RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4362.1 0.2X
-RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1520.6 0.5X
+In-memory 7 8 0 1.4 713.7 1.0X
+RocksDB (trackTotalNumberOfRows: true) 45 47 1 0.2 4538.6 0.2X
+RocksDB (trackTotalNumberOfRows: false) 16 16 0 0.6 1579.3 0.5X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
trying to delete 10000 rows from 10000 rows(0 rows are non-existing - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 7 8 0 1.4 725.8 1.0X
-RocksDB (trackTotalNumberOfRows: true) 43 45 1 0.2 4310.1 0.2X
-RocksDB (trackTotalNumberOfRows: false) 15 16 1 0.7 1528.4 0.5X
+In-memory 7 8 0 1.4 716.9 1.0X
+RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4459.8 0.2X
+RocksDB (trackTotalNumberOfRows: false) 16 16 1 0.6 1580.7 0.5X
================================================================================================
evict rows
================================================================================================
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
evicting 10000 rows (maxTimestampToEvictInMillis: 9999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 7 8 0 1.4 715.6 1.0X
-RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4386.8 0.2X
-RocksDB (trackTotalNumberOfRows: false) 17 17 0 0.6 1686.7 0.4X
+In-memory 7 7 0 1.5 689.5 1.0X
+RocksDB (trackTotalNumberOfRows: true) 44 45 1 0.2 4424.0 0.2X
+RocksDB (trackTotalNumberOfRows: false) 18 18 0 0.6 1784.2 0.4X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
evicting 5000 rows (maxTimestampToEvictInMillis: 4999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 7 7 0 1.5 667.7 1.0X
-RocksDB (trackTotalNumberOfRows: true) 23 24 1 0.4 2292.5 0.3X
-RocksDB (trackTotalNumberOfRows: false) 10 10 0 1.0 994.3 0.7X
+In-memory 6 7 0 1.5 650.0 1.0X
+RocksDB (trackTotalNumberOfRows: true) 23 24 1 0.4 2347.8 0.3X
+RocksDB (trackTotalNumberOfRows: false) 10 11 0 1.0 1037.1 0.6X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
evicting 1000 rows (maxTimestampToEvictInMillis: 999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-----------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 6 6 0 1.7 603.8 1.0X
-RocksDB (trackTotalNumberOfRows: true) 7 8 0 1.3 749.5 0.8X
-RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.1 482.2 1.3X
+In-memory 6 6 0 1.7 585.4 1.0X
+RocksDB (trackTotalNumberOfRows: true) 8 8 0 1.3 766.5 0.8X
+RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.0 503.2 1.2X
-OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 21.0.6+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
evicting 0 rows (maxTimestampToEvictInMillis: -1) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 0 0 0 23.7 42.1 1.0X
-RocksDB (trackTotalNumberOfRows: true) 3 4 0 2.9 345.1 0.1X
-RocksDB (trackTotalNumberOfRows: false) 3 4 0 2.9 344.6 0.1X
+In-memory 0 0 0 25.0 40.1 1.0X
+RocksDB (trackTotalNumberOfRows: true) 4 4 0 2.8 359.1 0.1X
+RocksDB (trackTotalNumberOfRows: false) 4 4 0 2.8 359.9 0.1X
diff --git a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt
index 4a7c21d01c80e..a3688c16f100a 100644
--- a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt
+++ b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt
@@ -2,143 +2,143 @@
put rows
================================================================================================
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
putting 10000 rows (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
---------------------------------------------------------------------------------------------------------------------------------------
-In-memory 8 10 1 1.2 842.4 1.0X
-RocksDB (trackTotalNumberOfRows: true) 45 47 2 0.2 4529.0 0.2X
-RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1635.3 0.5X
+In-memory 8 9 1 1.2 816.3 1.0X
+RocksDB (trackTotalNumberOfRows: true) 45 47 1 0.2 4514.1 0.2X
+RocksDB (trackTotalNumberOfRows: false) 17 18 1 0.6 1682.7 0.5X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
putting 10000 rows (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------
-In-memory 8 10 1 1.2 831.7 1.0X
-RocksDB (trackTotalNumberOfRows: true) 47 48 1 0.2 4662.3 0.2X
-RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1625.5 0.5X
+In-memory 8 10 1 1.2 811.7 1.0X
+RocksDB (trackTotalNumberOfRows: true) 47 49 1 0.2 4694.9 0.2X
+RocksDB (trackTotalNumberOfRows: false) 17 18 1 0.6 1680.2 0.5X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
putting 10000 rows (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------
-In-memory 8 9 1 1.2 802.0 1.0X
-RocksDB (trackTotalNumberOfRows: true) 46 48 1 0.2 4634.4 0.2X
-RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1616.5 0.5X
+In-memory 8 9 1 1.3 786.5 1.0X
+RocksDB (trackTotalNumberOfRows: true) 47 48 1 0.2 4679.7 0.2X
+RocksDB (trackTotalNumberOfRows: false) 17 18 1 0.6 1650.0 0.5X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
putting 10000 rows (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
---------------------------------------------------------------------------------------------------------------------------------
-In-memory 8 9 1 1.2 828.2 1.0X
-RocksDB (trackTotalNumberOfRows: true) 46 47 1 0.2 4593.4 0.2X
-RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1596.2 0.5X
+In-memory 8 8 1 1.3 778.0 1.0X
+RocksDB (trackTotalNumberOfRows: true) 46 48 1 0.2 4629.4 0.2X
+RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1664.9 0.5X
================================================================================================
merge rows
================================================================================================
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
merging 10000 rows with 10 values per key (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-RocksDB (trackTotalNumberOfRows: true) 574 593 7 0.0 57382.2 1.0X
-RocksDB (trackTotalNumberOfRows: false) 186 191 3 0.1 18572.6 3.1X
+RocksDB (trackTotalNumberOfRows: true) 570 585 6 0.0 56996.2 1.0X
+RocksDB (trackTotalNumberOfRows: false) 184 190 3 0.1 18411.4 3.1X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
merging 10000 rows with 10 values per key (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------------------------------------------
-RocksDB (trackTotalNumberOfRows: true) 502 513 5 0.0 50183.7 1.0X
-RocksDB (trackTotalNumberOfRows: false) 185 191 3 0.1 18542.0 2.7X
+RocksDB (trackTotalNumberOfRows: true) 493 505 5 0.0 49327.2 1.0X
+RocksDB (trackTotalNumberOfRows: false) 181 188 3 0.1 18140.8 2.7X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
merging 10000 rows with 10 values per key (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------------------------------------------
-RocksDB (trackTotalNumberOfRows: true) 439 453 6 0.0 43896.3 1.0X
-RocksDB (trackTotalNumberOfRows: false) 184 190 3 0.1 18384.9 2.4X
+RocksDB (trackTotalNumberOfRows: true) 435 448 5 0.0 43484.3 1.0X
+RocksDB (trackTotalNumberOfRows: false) 183 188 3 0.1 18289.1 2.4X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
merging 10000 rows with 10 values per key (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
--------------------------------------------------------------------------------------------------------------------------------------------------------
-RocksDB (trackTotalNumberOfRows: true) 421 433 5 0.0 42057.9 1.0X
-RocksDB (trackTotalNumberOfRows: false) 184 192 3 0.1 18421.4 2.3X
+RocksDB (trackTotalNumberOfRows: true) 416 432 5 0.0 41606.2 1.0X
+RocksDB (trackTotalNumberOfRows: false) 183 189 3 0.1 18282.2 2.3X
================================================================================================
delete rows
================================================================================================
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
trying to delete 10000 rows from 10000 rows(10000 rows are non-existing - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 0 1 0 26.3 38.0 1.0X
-RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4510.4 0.0X
-RocksDB (trackTotalNumberOfRows: false) 16 17 0 0.6 1585.4 0.0X
+In-memory 0 1 0 26.6 37.7 1.0X
+RocksDB (trackTotalNumberOfRows: true) 45 47 1 0.2 4514.1 0.0X
+RocksDB (trackTotalNumberOfRows: false) 16 17 0 0.6 1587.8 0.0X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
trying to delete 10000 rows from 10000 rows(5000 rows are non-existing - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 7 7 0 1.5 673.9 1.0X
-RocksDB (trackTotalNumberOfRows: true) 46 47 1 0.2 4566.1 0.1X
-RocksDB (trackTotalNumberOfRows: false) 16 16 0 0.6 1572.0 0.4X
+In-memory 6 7 1 1.6 644.9 1.0X
+RocksDB (trackTotalNumberOfRows: true) 45 47 1 0.2 4524.6 0.1X
+RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1579.1 0.4X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
trying to delete 10000 rows from 10000 rows(1000 rows are non-existing - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 7 8 1 1.4 725.8 1.0X
-RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4481.0 0.2X
-RocksDB (trackTotalNumberOfRows: false) 16 16 0 0.6 1582.0 0.5X
+In-memory 7 8 1 1.4 698.2 1.0X
+RocksDB (trackTotalNumberOfRows: true) 45 46 1 0.2 4481.1 0.2X
+RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1585.3 0.4X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
trying to delete 10000 rows from 10000 rows(0 rows are non-existing - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 7 8 1 1.4 736.3 1.0X
-RocksDB (trackTotalNumberOfRows: true) 44 46 1 0.2 4449.2 0.2X
-RocksDB (trackTotalNumberOfRows: false) 16 16 1 0.6 1570.2 0.5X
+In-memory 7 8 1 1.4 707.0 1.0X
+RocksDB (trackTotalNumberOfRows: true) 43 45 1 0.2 4326.6 0.2X
+RocksDB (trackTotalNumberOfRows: false) 16 17 1 0.6 1560.6 0.5X
================================================================================================
evict rows
================================================================================================
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
evicting 10000 rows (maxTimestampToEvictInMillis: 9999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 7 8 0 1.4 719.2 1.0X
-RocksDB (trackTotalNumberOfRows: true) 43 44 1 0.2 4313.7 0.2X
-RocksDB (trackTotalNumberOfRows: false) 17 17 1 0.6 1652.9 0.4X
+In-memory 7 7 0 1.4 693.7 1.0X
+RocksDB (trackTotalNumberOfRows: true) 43 44 1 0.2 4285.3 0.2X
+RocksDB (trackTotalNumberOfRows: false) 17 18 0 0.6 1726.3 0.4X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
evicting 5000 rows (maxTimestampToEvictInMillis: 4999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 7 7 0 1.5 670.7 1.0X
-RocksDB (trackTotalNumberOfRows: true) 23 24 1 0.4 2332.2 0.3X
-RocksDB (trackTotalNumberOfRows: false) 10 11 0 1.0 1026.8 0.7X
+In-memory 6 7 0 1.5 646.3 1.0X
+RocksDB (trackTotalNumberOfRows: true) 24 24 0 0.4 2351.2 0.3X
+RocksDB (trackTotalNumberOfRows: false) 11 11 0 0.9 1062.9 0.6X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
evicting 1000 rows (maxTimestampToEvictInMillis: 999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-----------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 6 7 0 1.6 610.3 1.0X
-RocksDB (trackTotalNumberOfRows: true) 8 8 0 1.3 767.9 0.8X
-RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.0 507.6 1.2X
+In-memory 6 6 0 1.7 587.7 1.0X
+RocksDB (trackTotalNumberOfRows: true) 8 8 0 1.3 784.7 0.7X
+RocksDB (trackTotalNumberOfRows: false) 5 6 0 1.9 529.1 1.1X
-OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1021-azure
+OpenJDK 64-Bit Server VM 17.0.14+7-LTS on Linux 6.8.0-1020-azure
AMD EPYC 7763 64-Core Processor
evicting 0 rows (maxTimestampToEvictInMillis: -1) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------------------------------
-In-memory 0 0 0 23.1 43.3 1.0X
-RocksDB (trackTotalNumberOfRows: true) 4 4 0 2.7 370.8 0.1X
-RocksDB (trackTotalNumberOfRows: false) 4 4 0 2.7 371.8 0.1X
+In-memory 0 0 0 23.2 43.2 1.0X
+RocksDB (trackTotalNumberOfRows: true) 4 4 0 2.6 387.5 0.1X
+RocksDB (trackTotalNumberOfRows: false) 4 4 0 2.6 389.4 0.1X
diff --git a/sql/core/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java b/sql/core/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java
index c927991425cd5..e6d89eaf3e39c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java
@@ -18,33 +18,31 @@
package org.apache.spark.sql.avro;
import java.util.Arrays;
+import java.util.EnumMap;
import java.util.Locale;
-import java.util.Map;
import java.util.stream.Collectors;
-import org.apache.avro.file.*;
+import org.apache.avro.file.DataFileConstants;
/**
* A mapper class from Spark supported avro compression codecs to avro compression codecs.
*/
public enum AvroCompressionCodec {
- UNCOMPRESSED(DataFileConstants.NULL_CODEC, false, -1),
- DEFLATE(DataFileConstants.DEFLATE_CODEC, true, CodecFactory.DEFAULT_DEFLATE_LEVEL),
- SNAPPY(DataFileConstants.SNAPPY_CODEC, false, -1),
- BZIP2(DataFileConstants.BZIP2_CODEC, false, -1),
- XZ(DataFileConstants.XZ_CODEC, true, CodecFactory.DEFAULT_XZ_LEVEL),
- ZSTANDARD(DataFileConstants.ZSTANDARD_CODEC, true, CodecFactory.DEFAULT_ZSTANDARD_LEVEL);
+ UNCOMPRESSED(DataFileConstants.NULL_CODEC, false),
+ DEFLATE(DataFileConstants.DEFLATE_CODEC, true),
+ SNAPPY(DataFileConstants.SNAPPY_CODEC, false),
+ BZIP2(DataFileConstants.BZIP2_CODEC, false),
+ XZ(DataFileConstants.XZ_CODEC, true),
+ ZSTANDARD(DataFileConstants.ZSTANDARD_CODEC, true);
private final String codecName;
private final boolean supportCompressionLevel;
- private final int defaultCompressionLevel;
AvroCompressionCodec(
String codecName,
- boolean supportCompressionLevel, int defaultCompressionLevel) {
+ boolean supportCompressionLevel) {
this.codecName = codecName;
this.supportCompressionLevel = supportCompressionLevel;
- this.defaultCompressionLevel = defaultCompressionLevel;
}
public String getCodecName() {
@@ -55,19 +53,19 @@ public boolean getSupportCompressionLevel() {
return this.supportCompressionLevel;
}
- public int getDefaultCompressionLevel() {
- return this.defaultCompressionLevel;
+ public static AvroCompressionCodec fromString(String s) {
+ return AvroCompressionCodec.valueOf(s.toUpperCase(Locale.ROOT));
}
-  private static final Map<String, String> codecNameMap =
+  private static final EnumMap<AvroCompressionCodec, String> codecNameMap =
Arrays.stream(AvroCompressionCodec.values()).collect(
- Collectors.toMap(codec -> codec.name(), codec -> codec.name().toLowerCase(Locale.ROOT)));
+ Collectors.toMap(
+ codec -> codec,
+ codec -> codec.name().toLowerCase(Locale.ROOT),
+ (oldValue, newValue) -> oldValue,
+ () -> new EnumMap<>(AvroCompressionCodec.class)));
public String lowerCaseName() {
- return codecNameMap.get(this.name());
- }
-
- public static AvroCompressionCodec fromString(String s) {
- return AvroCompressionCodec.valueOf(s.toUpperCase(Locale.ROOT));
+ return codecNameMap.get(this);
}
}
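
The same EnumMap refactor is applied to the ORC and Parquet codec enums below. Keying the cached lowercase names by the enum constant instead of its String name lets the JDK back the map with an ordinal-indexed array, so lookups skip hashing and string comparison entirely. A small Scala sketch of the idea against a stock JDK enum (illustrative only; the patched classes themselves stay in Java):

    import java.time.DayOfWeek
    import java.util.{EnumMap, Locale}

    object EnumMapSketch extends App {
      // EnumMap takes the key type up front and stores values in an array indexed by ordinal.
      private val lowerNames = new EnumMap[DayOfWeek, String](classOf[DayOfWeek])
      DayOfWeek.values().foreach(d => lowerNames.put(d, d.name().toLowerCase(Locale.ROOT)))

      assert(lowerNames.get(DayOfWeek.MONDAY) == "monday")
    }
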
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcCompressionCodec.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcCompressionCodec.java
index 55bc9d04b4400..abea45122307e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcCompressionCodec.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcCompressionCodec.java
@@ -18,8 +18,8 @@
package org.apache.spark.sql.execution.datasources.orc;
import java.util.Arrays;
+import java.util.EnumMap;
import java.util.Locale;
-import java.util.Map;
import java.util.stream.Collectors;
import org.apache.orc.CompressionKind;
@@ -47,11 +47,15 @@ public CompressionKind getCompressionKind() {
return this.compressionKind;
}
-  public static final Map<String, String> codecNameMap =
+  private static final EnumMap<OrcCompressionCodec, String> codecNameMap =
Arrays.stream(OrcCompressionCodec.values()).collect(
- Collectors.toMap(codec -> codec.name(), codec -> codec.name().toLowerCase(Locale.ROOT)));
+ Collectors.toMap(
+ codec -> codec,
+ codec -> codec.name().toLowerCase(Locale.ROOT),
+ (oldValue, newValue) -> oldValue,
+ () -> new EnumMap<>(OrcCompressionCodec.class)));
public String lowerCaseName() {
- return codecNameMap.get(this.name());
+ return codecNameMap.get(this);
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodec.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodec.java
index 32d9701bdbb21..8dfcc6c7c60c8 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodec.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodec.java
@@ -18,9 +18,9 @@
package org.apache.spark.sql.execution.datasources.parquet;
import java.util.Arrays;
+import java.util.EnumMap;
import java.util.List;
import java.util.Locale;
-import java.util.Map;
import java.util.stream.Collectors;
import org.apache.parquet.hadoop.metadata.CompressionCodecName;
@@ -53,12 +53,16 @@ public static ParquetCompressionCodec fromString(String s) {
return ParquetCompressionCodec.valueOf(s.toUpperCase(Locale.ROOT));
}
-  private static final Map<String, String> codecNameMap =
+  private static final EnumMap<ParquetCompressionCodec, String> codecNameMap =
Arrays.stream(ParquetCompressionCodec.values()).collect(
- Collectors.toMap(codec -> codec.name(), codec -> codec.name().toLowerCase(Locale.ROOT)));
+ Collectors.toMap(
+ codec -> codec,
+ codec -> codec.name().toLowerCase(Locale.ROOT),
+ (oldValue, newValue) -> oldValue,
+ () -> new EnumMap<>(ParquetCompressionCodec.class)));
public String lowerCaseName() {
- return codecNameMap.get(this.name());
+ return codecNameMap.get(this);
}
  public static final List<ParquetCompressionCodec> availableCodecs =
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 70f806ba14f03..889f11e119730 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -158,6 +158,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
return new LongUpdater();
} else if (canReadAsDecimal(descriptor, sparkType)) {
return new LongToDecimalUpdater(descriptor, (DecimalType) sparkType);
+ } else if (sparkType instanceof TimeType) {
+ return new LongUpdater();
}
}
case FLOAT -> {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 1882d990bef55..da52cdf5c835c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -619,7 +619,8 @@ protected void reserveInternal(int newCapacity) {
this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L);
} else if (type instanceof LongType || type instanceof DoubleType ||
DecimalType.is64BitDecimalType(type) || type instanceof TimestampType ||
- type instanceof TimestampNTZType || type instanceof DayTimeIntervalType) {
+ type instanceof TimestampNTZType || type instanceof DayTimeIntervalType ||
+ type instanceof TimeType) {
this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L);
} else if (childColumns != null) {
// Nothing to store.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 1908b511269a6..fd3b07e3e2171 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -613,7 +613,8 @@ protected void reserveInternal(int newCapacity) {
}
} else if (type instanceof LongType ||
type instanceof TimestampType ||type instanceof TimestampNTZType ||
- DecimalType.is64BitDecimalType(type) || type instanceof DayTimeIntervalType) {
+ DecimalType.is64BitDecimalType(type) || type instanceof DayTimeIntervalType ||
+ type instanceof TimeType) {
if (longData == null || longData.length < newCapacity) {
long[] newData = new long[newCapacity];
if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity);
diff --git a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
index 1374bd100a2fe..ce83c285410b5 100644
--- a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
+++ b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
@@ -48,6 +48,13 @@ message StateResponseWithStringTypeVal {
string value = 3;
}
+message StateResponseWithListGet {
+ int32 statusCode = 1;
+ string errorMessage = 2;
+ repeated bytes value = 3;
+ bool requireNextFetch = 4;
+}
+
message StatefulProcessorCall {
oneof method {
SetHandleState setHandleState = 1;
@@ -197,6 +204,8 @@ message ListStateGet {
}
message ListStatePut {
+ repeated bytes value = 1;
+ bool fetchWithArrow = 2;
}
message AppendValue {
@@ -204,6 +213,8 @@ message AppendValue {
}
message AppendList {
+ repeated bytes value = 1;
+ bool fetchWithArrow = 2;
}
message GetValue {
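Aside (not part of the patch): the new StateResponseWithListGet message pairs a batch of serialized list elements with a requireNextFetch flag, which suggests the client keeps fetching until the server reports that no more batches remain. A hedged sketch of that loop shape; ListGetResponse and fetchNextBatch below are hypothetical stand-ins, not the generated protobuf API:

import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for the generated message, for illustration only.
final case class ListGetResponse(values: Seq[Array[Byte]], requireNextFetch: Boolean)

def collectAllListValues(fetchNextBatch: () => ListGetResponse): Seq[Array[Byte]] = {
  val buf = ArrayBuffer.empty[Array[Byte]]
  var more = true
  while (more) {
    val resp = fetchNextBatch()      // one round trip per batch of list elements
    buf ++= resp.values
    more = resp.requireNextFetch     // server signals whether another fetch is needed
  }
  buf.toSeq
}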
diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css
index d6a498e93872c..0329579406814 100644
--- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css
+++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css
@@ -15,35 +15,35 @@
* limitations under the License.
*/
-#plan-viz-graph .label {
+svg g.label {
font-size: 0.85rem;
font-weight: normal;
text-shadow: none;
color: #333;
}
-#plan-viz-graph svg g.cluster rect {
+svg g.cluster rect {
fill: #A0DFFF;
stroke: #3EC0FF;
stroke-width: 1px;
}
-#plan-viz-graph svg g.node rect {
+svg g.node rect {
fill: #C3EBFF;
stroke: #3EC0FF;
stroke-width: 1px;
}
/* Highlight the SparkPlan node name */
-#plan-viz-graph svg text :first-child:not(.stageId-and-taskId-metrics) {
+svg text :first-child:not(.stageId-and-taskId-metrics) {
font-weight: bold;
}
-#plan-viz-graph svg text {
+svg text {
fill: #333;
}
-#plan-viz-graph svg path {
+svg path {
stroke: #444;
stroke-width: 1.5px;
}
@@ -58,19 +58,19 @@
word-wrap: break-word;
}
-#plan-viz-graph svg g.node rect.selected {
+svg g.node rect.selected {
fill: #E25A1CFF;
stroke: #317EACFF;
stroke-width: 2px;
}
-#plan-viz-graph svg g.node rect.linked {
+svg g.node rect.linked {
fill: #FFC106FF;
stroke: #317EACFF;
stroke-width: 2px;
}
-#plan-viz-graph svg path.linked {
+svg path.linked {
fill: #317EACFF;
stroke: #317EACFF;
stroke-width: 2px;
diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js
index d4cc45a1639ab..faf30d5d54225 100644
--- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js
+++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js
@@ -22,13 +22,10 @@ var PlanVizConstants = {
svgMarginY: 16
};
-/* eslint-disable no-unused-vars */
function shouldRenderPlanViz() {
return planVizContainer().selectAll("svg").empty();
}
-/* eslint-enable no-unused-vars */
-/* eslint-disable no-unused-vars */
function renderPlanViz() {
var svg = planVizContainer().append("svg");
var metadata = d3.select("#plan-viz-metadata");
@@ -52,7 +49,6 @@ function renderPlanViz() {
resizeSvg(svg);
postprocessForAdditionalMetrics();
}
-/* eslint-enable no-unused-vars */
/* -------------------- *
* | Helper functions | *
@@ -312,3 +308,49 @@ function collectLinks(map, key, value) {
}
map.get(key).add(value);
}
+
+function downloadPlanBlob(b, ext) {
+ const link = document.createElement("a");
+ link.href = URL.createObjectURL(b);
+ link.download = `plan.${ext}`;
+ link.click();
+}
+
+document.getElementById("plan-viz-download-btn").addEventListener("click", async function () {
+ const format = document.getElementById("plan-viz-format-select").value;
+ let blob;
+ if (format === "svg") {
+ const svg = planVizContainer().select("svg").node().cloneNode(true);
+ let css = "";
+ try {
+ css = await fetch("/static/sql/spark-sql-viz.css").then((resp) => resp.text());
+ } catch (e) {
+ console.error("Failed to fetch CSS for SVG download", e);
+ }
+ d3.select(svg).insert("style", ":first-child").text(css);
+ const svgData = new XMLSerializer().serializeToString(svg);
+ blob = new Blob([svgData], { type: "image/svg+xml" });
+ } else if (format === "dot") {
+ const dot = d3.select("#plan-viz-metadata .dot-file").text().trim();
+ blob = new Blob([dot], { type: "text/plain" });
+ } else if (format === "txt") {
+ const txt = d3.select("#physical-plan-details pre").text().trim();
+ blob = new Blob([txt], { type: "text/plain" });
+ } else {
+ return;
+ }
+ downloadPlanBlob(blob, format);
+});
+
+/* eslint-disable no-unused-vars */
+function clickPhysicalPlanDetails() {
+/* eslint-enable no-unused-vars */
+ $('#physical-plan-details').toggle();
+ $('#physical-plan-details-arrow').toggleClass('arrow-open').toggleClass('arrow-closed');
+}
+
+document.addEventListener("DOMContentLoaded", function () {
+ if (shouldRenderPlanViz()) {
+ renderPlanViz();
+ }
+});
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 374d38db371a2..40779c66600fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.api.python
import java.io.InputStream
-import java.net.Socket
-import java.nio.channels.Channels
+import java.nio.channels.{Channels, SocketChannel}
import net.razorvine.pickle.{Pickler, Unpickler}
@@ -197,8 +196,8 @@ private[sql] object PythonSQLUtils extends Logging {
private[spark] class ArrowIteratorServer
extends SocketAuthServer[Iterator[Array[Byte]]]("pyspark-arrow-batches-server") {
- def handleConnection(sock: Socket): Iterator[Array[Byte]] = {
- val in = sock.getInputStream()
+ def handleConnection(sock: SocketChannel): Iterator[Array[Byte]] = {
+ val in = Channels.newInputStream(sock)
val dechunkedInput: InputStream = new DechunkedInputStream(in)
// Create array to consume iterator so that we can safely close the file
ArrowConverters.getBatchesFromStream(Channels.newChannel(dechunkedInput)).toArray.iterator
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index ac20614553ca2..65fafb5a34c6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.avro
import java.math.BigDecimal
import java.nio.ByteBuffer
+import java.time.ZoneOffset
import scala.jdk.CollectionConverters._
@@ -57,7 +58,7 @@ private[sql] class AvroDeserializer(
def this(
rootAvroType: Schema,
rootCatalystType: DataType,
- datetimeRebaseMode: String,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
useStableIdForUnionType: Boolean,
stableIdPrefixForUnionType: String,
recursiveFieldMaxDepth: Int) = {
@@ -65,7 +66,7 @@ private[sql] class AvroDeserializer(
rootAvroType,
rootCatalystType,
positionalFieldMatch = false,
- RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)),
+ RebaseSpec(datetimeRebaseMode),
new NoopFilters,
useStableIdForUnionType,
stableIdPrefixForUnionType,
@@ -159,6 +160,12 @@ private[sql] class AvroDeserializer(
case (INT, DateType) => (updater, ordinal, value) =>
updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int]))
+ case (INT, TimestampNTZType) if avroType.getLogicalType.isInstanceOf[LogicalTypes.Date] =>
+ (updater, ordinal, value) =>
+ val days = dateRebaseFunc(value.asInstanceOf[Int])
+ val micros = DateTimeUtils.daysToMicros(days, ZoneOffset.UTC)
+ updater.setLong(ordinal, micros)
+
case (LONG, dt: DatetimeType)
if preventReadingIncorrectType && realDataType.isInstanceOf[DayTimeIntervalType] =>
throw QueryCompilationErrors.avroIncompatibleReadError(toFieldStr(avroPath),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
index d571b3ed6050e..ab3607d1bd7a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.avro
import java.net.URI
+import java.util.HashMap
import org.apache.avro.Schema
import org.apache.hadoop.conf.Configuration
@@ -28,7 +29,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode}
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
/**
* Options for Avro Reader and Writer stored in case insensitive manner.
@@ -128,8 +129,8 @@ private[sql] class AvroOptions(
/**
* The rebasing mode for the DATE and TIMESTAMP_MICROS, TIMESTAMP_MILLIS values in reads.
*/
- val datetimeRebaseModeInRead: String = parameters
- .get(DATETIME_REBASE_MODE)
+ val datetimeRebaseModeInRead: LegacyBehaviorPolicy.Value = parameters
+ .get(DATETIME_REBASE_MODE).map(LegacyBehaviorPolicy.withName)
.getOrElse(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_READ))
val useStableIdForUnionType: Boolean =
@@ -146,6 +147,37 @@ private[sql] class AvroOptions(
RECURSIVE_FIELD_MAX_DEPTH,
s"Should not be greater than $RECURSIVE_FIELD_MAX_DEPTH_LIMIT.")
}
+
+ /**
+ * [[hadoop.conf.Configuration]] is not comparable, so we turn it into a map for [[equals]] and
+ * [[hashCode]].
+ */
+ @transient private lazy val comparableConf = {
+ val iter = conf.iterator()
+ val result = new HashMap[String, String]
+ while (iter.hasNext()) {
+ val entry = iter.next()
+ result.put(entry.getKey(), entry.getValue())
+ }
+ result
+ }
+
+ override def equals(other: Any): Boolean = {
+ other match {
+ case that: AvroOptions =>
+ this.parameters == that.parameters &&
+ this.comparableConf == that.comparableConf
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = {
+ val prime = 31
+ var result = 1
+ result = prime * result + parameters.hashCode
+ result = prime * result + comparableConf.hashCode
+ result
+ }
}
private[sql] object AvroOptions extends DataSourceOptions {
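Aside (not part of the patch): a minimal sketch of what the new equals/hashCode above buys: two AvroOptions built from the same options and an equivalent Hadoop Configuration now compare equal, so they can be deduplicated or used as cache keys. The Map + Configuration convenience constructor is assumed, and AvroOptions is private[sql], so this only compiles inside Spark's own sql packages:

import org.apache.hadoop.conf.Configuration

val hadoopConf = new Configuration()
// "mode" is a standard Avro read option; any identical option map works here.
val a = new AvroOptions(Map("mode" -> "FAILFAST"), hadoopConf)
val b = new AvroOptions(Map("mode" -> "FAILFAST"), hadoopConf)
assert(a == b && a.hashCode == b.hashCode)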
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
index c4aaacf515453..767216b819926 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
@@ -44,8 +44,7 @@ private[avro] class AvroOutputWriter(
avroSchema: Schema) extends OutputWriter {
// Whether to rebase datetimes from Gregorian to Julian calendar in write
- private val datetimeRebaseMode = LegacyBehaviorPolicy.withName(
- SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE))
+ private val datetimeRebaseMode = SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE)
// The input rows will never be null.
private lazy val serializer = new AvroSerializer(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 1d83a46a278f7..402bab666948d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -53,7 +53,7 @@ private[sql] class AvroSerializer(
def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) = {
this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = false,
- LegacyBehaviorPolicy.withName(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE)))
+ SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE))
}
def serialize(catalystData: Any): Any = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
index 3299b34bcc933..b388c98ffcb1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -83,6 +83,7 @@ private[sql] object AvroUtils extends Logging {
def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: VariantType => false
+ case _: TimeType => false
case _: AtomicType => true
case st: StructType => st.forall { f => supportsDataType(f.dataType) }
@@ -121,18 +122,20 @@ private[sql] object AvroUtils extends Logging {
jobConf.setBoolean("mapreduce.output.fileoutputformat.compress", true)
jobConf.set(AvroJob.CONF_OUTPUT_CODEC, compressed.getCodecName)
if (compressed.getSupportCompressionLevel) {
- val level = sqlConf.getConfString(s"spark.sql.avro.$codecName.level",
- compressed.getDefaultCompressionLevel.toString)
- logInfo(log"Compressing Avro output using the ${MDC(CODEC_NAME, codecName)} codec " +
- log"at level ${MDC(CODEC_LEVEL, level)}")
- val s = if (compressed == ZSTANDARD) {
- val bufferPoolEnabled = sqlConf.getConf(SQLConf.AVRO_ZSTANDARD_BUFFER_POOL_ENABLED)
- jobConf.setBoolean(AvroOutputFormat.ZSTD_BUFFERPOOL_KEY, bufferPoolEnabled)
- "zstd"
- } else {
- codecName
+ val levelAndCodecName = compressed match {
+ case DEFLATE => Some(sqlConf.getConf(SQLConf.AVRO_DEFLATE_LEVEL), codecName)
+ case XZ => Some(sqlConf.getConf(SQLConf.AVRO_XZ_LEVEL), codecName)
+ case ZSTANDARD =>
+ jobConf.setBoolean(AvroOutputFormat.ZSTD_BUFFERPOOL_KEY,
+ sqlConf.getConf(SQLConf.AVRO_ZSTANDARD_BUFFER_POOL_ENABLED))
+ Some(sqlConf.getConf(SQLConf.AVRO_ZSTANDARD_LEVEL), "zstd")
+ case _ => None
+ }
+ levelAndCodecName.foreach { case (level, mapredCodecName) =>
+ logInfo(log"Compressing Avro output using the ${MDC(CODEC_NAME, codecName)} " +
+ log"codec at level ${MDC(CODEC_LEVEL, level)}")
+ jobConf.setInt(s"avro.mapred.$mapredCodecName.level", level.toInt)
}
- jobConf.setInt(s"avro.mapred.$s.level", level.toInt)
} else {
logInfo(log"Compressing Avro output using the ${MDC(CODEC_NAME, codecName)} codec")
}
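Aside (not part of the patch): a hedged usage sketch of the codec-level handling above, assuming an active SparkSession `spark` and a DataFrame `df`; the conf key strings are assumptions inferred from the SQLConf constants referenced in the code (e.g. SQLConf.AVRO_ZSTANDARD_LEVEL):

// Write zstd-compressed Avro output at an explicit compression level.
spark.conf.set("spark.sql.avro.compression.codec", "zstandard")
spark.conf.set("spark.sql.avro.zstandard.level", "9")  // conf key name assumed
df.write.format("avro").save("/tmp/avro-zstd-example")  // hypothetical output path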
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataSource.scala
index 249ea6e6d04cb..07208ca7760db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataSource.scala
@@ -140,6 +140,6 @@ class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {
paths = finalPaths,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
- options = finalOptions.originalMap).resolveRelation())
+ options = finalOptions.originalMap).resolveRelation(readOnly = true))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala
index aff65496b763b..1e028d2046eb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.classic
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
+import org.apache.spark.SparkThrowable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalog
import org.apache.spark.sql.catalog.{CatalogMetadata, Column, Database, Function, Table}
@@ -30,13 +31,13 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, UnresolvedTableSpec, View}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowTables, UnresolvedTableSpec, View}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{CatalogManager, SupportsNamespaces, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, MultipartIdentifierHelper, NamespaceHelper, TransformHelper}
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.command.ShowTablesCommand
+import org.apache.spark.sql.execution.command.{ShowNamespacesCommand, ShowTablesCommand}
import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.connector.V1Function
@@ -105,10 +106,10 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
listDatabasesInternal(Some(pattern))
private def listDatabasesInternal(patternOpt: Option[String]): Dataset[Database] = {
- val plan = ShowNamespaces(UnresolvedNamespace(Nil), patternOpt)
+ val plan = ShowNamespacesCommand(UnresolvedNamespace(Nil), patternOpt)
val qe = sparkSession.sessionState.executePlan(plan)
val catalog = qe.analyzed.collectFirst {
- case ShowNamespaces(r: ResolvedNamespace, _, _) => r.catalog
+ case ShowNamespacesCommand(r: ResolvedNamespace, _, _) => r.catalog
}.get
val databases = qe.toRdd.collect().map { row =>
// dbName can either be a quoted identifier (single or multi part) or an unquoted single part
@@ -167,18 +168,25 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
private[sql] def resolveTable(row: InternalRow, catalogName: String): Option[Table] = {
val tableName = row.getString(1)
val namespaceName = row.getString(0)
- val isTemp = row.getBoolean(2)
+ val isTempView = row.getBoolean(2)
+ val ns = if (isTempView) {
+ if (namespaceName.isEmpty) Nil else Seq(namespaceName)
+ } else {
+ parseIdent(namespaceName)
+ }
+ val nameParts = if (isTempView) {
+ // Temp views do not belong to any catalog. We shouldn't prepend the catalog name here.
+ ns :+ tableName
+ } else {
+ catalogName +: ns :+ tableName
+ }
try {
- if (isTemp) {
- // Temp views do not belong to any catalog. We shouldn't prepend the catalog name here.
- val ns = if (namespaceName.isEmpty) Nil else Seq(namespaceName)
- Some(makeTable(ns :+ tableName))
- } else {
- val ns = parseIdent(namespaceName)
- try {
- Some(makeTable(catalogName +: ns :+ tableName))
- } catch {
- case e: AnalysisException if e.getCondition == "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE" =>
+ Some(makeTable(nameParts))
+ } catch {
+ case e: SparkThrowable with Throwable =>
+ Catalog.ListTable.ERROR_HANDLING_RULES.get(e.getCondition) match {
+ case Some(Catalog.ListTable.Skip) => None
+ case Some(Catalog.ListTable.ReturnPartialResults) if !isTempView =>
Some(new Table(
name = tableName,
catalog = catalogName,
@@ -187,10 +195,8 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
tableType = null,
isTemporary = false
))
+ case _ => throw e
}
- }
- } catch {
- case e: AnalysisException if e.getCondition == "TABLE_OR_VIEW_NOT_FOUND" => None
}
}
@@ -688,7 +694,8 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
comment = { if (description.isEmpty) None else Some(description) },
collation = None,
serde = None,
- external = tableType == CatalogTableType.EXTERNAL)
+ external = tableType == CatalogTableType.EXTERNAL,
+ constraints = Seq.empty)
val plan = CreateTable(
name = UnresolvedIdentifier(ident),
@@ -957,4 +964,18 @@ private[sql] object Catalog {
}
private val FUNCTION_EXISTS_COMMAND_NAME = "Catalog.functionExists"
+
+ private object ListTable {
+
+ sealed trait ErrorHandlingAction
+
+ case object Skip extends ErrorHandlingAction
+
+ case object ReturnPartialResults extends ErrorHandlingAction
+
+ val ERROR_HANDLING_RULES: Map[String, ErrorHandlingAction] = Map(
+ "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE" -> ReturnPartialResults,
+ "TABLE_OR_VIEW_NOT_FOUND" -> Skip
+ )
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameReader.scala
index 489ac31e5291d..bc01517e1c6ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameReader.scala
@@ -27,6 +27,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql
import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.DataSourceOptions
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
import org.apache.spark.sql.catalyst.expressions.ExprUtils
@@ -35,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.logical.UnresolvedDataSource
import org.apache.spark.sql.catalyst.util.FailureSafeParser
import org.apache.spark.sql.catalyst.xml.{StaxXmlParser, XmlOptions}
import org.apache.spark.sql.classic.ClassicConversions._
-import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.JsonUtils.checkJsonSchema
@@ -335,12 +335,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession)
override def textFile(paths: String*): Dataset[String] = super.textFile(paths: _*)
/** @inheritdoc */
- override protected def validateSingleVariantColumn(): Unit = {
- if (extraOptions.get(JSONOptions.SINGLE_VARIANT_COLUMN).isDefined &&
- userSpecifiedSchema.isDefined) {
- throw QueryCompilationErrors.invalidSingleVariantColumn()
- }
- }
+ override protected def validateSingleVariantColumn(): Unit =
+ DataSourceOptions.validateSingleVariantColumn(extraOptions, userSpecifiedSchema)
override protected def validateJsonSchema(): Unit =
userSpecifiedSchema.foreach(checkJsonSchema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
index b423c89fff3db..501b4985128dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
@@ -213,7 +213,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram
comment = extraOptions.get(TableCatalog.PROP_COMMENT),
collation = extraOptions.get(TableCatalog.PROP_COLLATION),
serde = None,
- external = false)
+ external = false,
+ constraints = Seq.empty)
runCommand(df.sparkSession) {
CreateTableAsSelect(
UnresolvedIdentifier(
@@ -478,7 +479,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram
comment = extraOptions.get(TableCatalog.PROP_COMMENT),
collation = extraOptions.get(TableCatalog.PROP_COLLATION),
serde = None,
- external = false)
+ external = false,
+ constraints = Seq.empty)
ReplaceTableAsSelect(
UnresolvedIdentifier(nameParts),
partitioningAsV2,
@@ -499,7 +501,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram
comment = extraOptions.get(TableCatalog.PROP_COMMENT),
collation = extraOptions.get(TableCatalog.PROP_COLLATION),
serde = None,
- external = false)
+ external = false,
+ constraints = Seq.empty)
CreateTableAsSelect(
UnresolvedIdentifier(nameParts),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
index 5dee09175839c..c6eacfe8f1ed9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
@@ -169,7 +169,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
comment = None,
collation = None,
serde = None,
- external = false)
+ external = false,
+ constraints = Seq.empty)
}
/** @inheritdoc */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala
index 96e8755577542..471c5feadaabc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala
@@ -175,7 +175,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D
None,
None,
None,
- external = false)
+ external = false,
+ constraints = Seq.empty)
val cmd = CreateTable(
UnresolvedIdentifier(originalMultipartIdentifier),
ds.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index 9d1ca3ce5fa0a..5c3ebb32b36a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder, StructEncoder}
-import org.apache.spark.sql.catalyst.expressions.{ScalarSubquery => ScalarSubqueryExpr, _}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans._
@@ -69,8 +69,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.util.{NextIterator, Utils}
import org.apache.spark.util.ArrayImplicits._
-import org.apache.spark.util.Utils
private[sql] object Dataset {
val curId = new java.util.concurrent.atomic.AtomicLong()
@@ -929,7 +929,18 @@ class Dataset[T] private[sql](
/** @inheritdoc */
@scala.annotation.varargs
def groupBy(cols: Column*): RelationalGroupedDataset = {
- RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType)
+ // Replace top-level integer literals in grouping expressions with ordinals, if
+ // `groupByOrdinal` is enabled.
+ val groupingExpressionsWithOrdinals = cols.map { col => col.expr match {
+ case Literal(value: Int, IntegerType) if sparkSession.sessionState.conf.groupByOrdinal =>
+ UnresolvedOrdinal(value)
+ case other => other
+ }}
+ RelationalGroupedDataset(
+ df = toDF(),
+ groupingExprs = groupingExpressionsWithOrdinals,
+ groupType = RelationalGroupedDataset.GroupByType
+ )
}
/** @inheritdoc */
@@ -1062,16 +1073,6 @@ class Dataset[T] private[sql](
)
}
- /** @inheritdoc */
- def scalar(): Column = {
- Column(ExpressionColumnNode(ScalarSubqueryExpr(logicalPlan)))
- }
-
- /** @inheritdoc */
- def exists(): Column = {
- Column(ExpressionColumnNode(Exists(logicalPlan)))
- }
-
/** @inheritdoc */
@scala.annotation.varargs
def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withSameTypedPlan {
@@ -1573,7 +1574,7 @@ class Dataset[T] private[sql](
sparkSession.sessionState.executePlan(deserialized)
}
- private lazy val materializedRdd: RDD[T] = {
+ private[sql] lazy val materializedRdd: RDD[T] = {
val objectType = exprEnc.deserializer.dataType
rddQueryExecution.toRdd.mapPartitions { rows =>
rows.map(_.get(0, objectType).asInstanceOf[T])
@@ -1677,21 +1678,19 @@ class Dataset[T] private[sql](
val gen = new JacksonGenerator(rowSchema, writer,
new JSONOptions(Map.empty[String, String], sessionLocalTimeZone))
- new Iterator[String] {
+ new NextIterator[String] {
private val toRow = exprEnc.createSerializer()
- override def hasNext: Boolean = iter.hasNext
- override def next(): String = {
+ override def close(): Unit = { gen.close() }
+ override def getNext(): String = {
+ if (!iter.hasNext) {
+ finished = true
+ return ""
+ }
+ writer.reset()
gen.write(toRow(iter.next()))
gen.flush()
- val json = writer.toString
- if (hasNext) {
- writer.reset()
- } else {
- gen.close()
- }
-
- json
+ writer.toString
}
}
} (Encoders.STRING)
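Aside (not part of the patch): the groupBy change above only rewrites how the grouping expressions are constructed; a standalone sketch of that rewrite, assuming `spark.sql.groupByOrdinal` is enabled so the analyzer can later resolve the ordinal positionally, as it does for SQL's GROUP BY 1:

import org.apache.spark.sql.catalyst.analysis.UnresolvedOrdinal
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.types.IntegerType

// Mirrors the rewrite in Dataset.groupBy above: a top-level integer literal becomes an
// UnresolvedOrdinal; any other expression is passed through unchanged.
def toOrdinalIfIntLiteral(e: Expression, groupByOrdinal: Boolean): Expression = e match {
  case Literal(value: Int, IntegerType) if groupByOrdinal => UnresolvedOrdinal(value)
  case other => other
}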
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
index 082292145e858..0fa6e91e21459 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
@@ -23,7 +23,7 @@ import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql
import org.apache.spark.sql.{AnalysisException, Column, Encoder}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedOrdinal}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -36,7 +36,7 @@ import org.apache.spark.sql.classic.TypedAggUtils.withInputType
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.sql.types.{IntegerType, NumericType, StructType}
import org.apache.spark.util.ArrayImplicits._
/**
@@ -67,7 +67,13 @@ class RelationalGroupedDataset protected[sql](
@scala.annotation.nowarn("cat=deprecation")
val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) {
- groupingExprs match {
+ // We need to unwrap ordinals from grouping expressions in order to add grouping columns to
+ // aggregate expressions.
+ val groupingExpressionsWithUnwrappedOrdinals = groupingExprs.map {
+ case UnresolvedOrdinal(value) => Literal(value, IntegerType)
+ case other => other
+ }
+ groupingExpressionsWithUnwrappedOrdinals match {
// call `toList` because `Stream` and `LazyList` can't serialize in scala 2.13
case s: LazyList[Expression] => s.toList ++ aggExprs
case s: Stream[Expression] => s.toList ++ aggExprs
@@ -461,6 +467,32 @@ class RelationalGroupedDataset protected[sql](
Dataset.ofRows(df.sparkSession, plan)
}
+ /**
+ * Applies a grouped Python user-defined function to each group of data.
+ * The user-defined function defines a transformation: iterator of `Row` -> iterator of `Row`.
+ * For each group, all elements in the group are passed as an iterator of `Row` along with
+ * the corresponding state, and the results for all groups are combined into a new [[DataFrame]].
+ *
+ * This function uses Apache Arrow as the serialization format between Java executors and
+ * Python workers.
+ */
+ private[sql] def transformWithStateInPySpark(
+ func: Column,
+ outputStructType: StructType,
+ outputModeStr: String,
+ timeModeStr: String,
+ initialState: RelationalGroupedDataset,
+ eventTimeColumnName: String): DataFrame = {
+ _transformWithStateInPySpark(
+ func,
+ outputStructType,
+ outputModeStr,
+ timeModeStr,
+ initialState,
+ eventTimeColumnName,
+ TransformWithStateInPySpark.UserFacingDataType.PYTHON_ROW)
+ }
+
/**
* Applies a grouped vectorized python user-defined function to each group of data.
* The user-defined function defines a transformation: iterator of `pandas.DataFrame` ->
@@ -479,6 +511,24 @@ class RelationalGroupedDataset protected[sql](
timeModeStr: String,
initialState: RelationalGroupedDataset,
eventTimeColumnName: String): DataFrame = {
+ _transformWithStateInPySpark(
+ func,
+ outputStructType,
+ outputModeStr,
+ timeModeStr,
+ initialState,
+ eventTimeColumnName,
+ TransformWithStateInPySpark.UserFacingDataType.PANDAS)
+ }
+
+ private def _transformWithStateInPySpark(
+ func: Column,
+ outputStructType: StructType,
+ outputModeStr: String,
+ timeModeStr: String,
+ initialState: RelationalGroupedDataset,
+ eventTimeColumnName: String,
+ userFacingDataType: TransformWithStateInPySpark.UserFacingDataType.Value): DataFrame = {
def exprToAttr(expr: Seq[Expression]): Seq[Attribute] = {
expr.map {
case ne: NamedExpression => ne
@@ -498,12 +548,13 @@ class RelationalGroupedDataset protected[sql](
Project(groupingAttrs ++ leftChild.output, leftChild)).analyzed
val plan: LogicalPlan = if (initialState == null) {
- TransformWithStateInPandas(
+ TransformWithStateInPySpark(
func.expr,
groupingAttrs.length,
outputAttrs,
outputMode,
timeMode,
+ userFacingDataType,
child = left,
hasInitialState = false,
/* The followings are dummy variables because hasInitialState is false */
@@ -519,12 +570,13 @@ class RelationalGroupedDataset protected[sql](
val right = initialState.df.sparkSession.sessionState.executePlan(
Project(initGroupingAttrs ++ rightChild.output, rightChild)).analyzed
- TransformWithStateInPandas(
+ TransformWithStateInPySpark(
func.expr,
groupingAttributesLen = groupingAttrs.length,
outputAttrs,
outputMode,
timeMode,
+ userFacingDataType,
child = left,
hasInitialState = true,
initialState = right,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala
index 6ce6f06de113d..6d4a3ecd36037 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala
@@ -53,7 +53,7 @@ class StreamingQueryManager private[sql] (
with Logging {
private[sql] val stateStoreCoordinator =
- StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env)
+ StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env, sqlConf)
private val listenerBus =
new StreamingQueryListenerBus(Some(sparkSession.sparkContext.listenerBus))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala
index 5766535ac5dac..5aec32c572dae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala
@@ -28,11 +28,12 @@ import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF, TypedAggregateExpression}
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator}
-import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SQLConf, SqlExpression, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame}
+import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SQLConf, SqlExpression, SubqueryExpression, SubqueryType, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame}
import org.apache.spark.sql.types.{DataType, NullType}
/**
@@ -192,6 +193,17 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres
case l: LazyExpression =>
analysis.LazyExpression(apply(l.child))
+ case SubqueryExpression(ds, subqueryType, _) =>
+ subqueryType match {
+ case SubqueryType.SCALAR =>
+ expressions.ScalarSubquery(ds.logicalPlan)
+ case SubqueryType.EXISTS =>
+ expressions.Exists(ds.logicalPlan)
+ case SubqueryType.IN(values) =>
+ expressions.InSubquery(
+ values.map(value => apply(value)), expressions.ListQuery(ds.logicalPlan))
+ }
+
case node =>
throw SparkException.internalError("Unsupported ColumnNode: " + node)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 8fe7565c902a5..ef2b5c1e19cd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.internal.{LogEntry, Logging, MDC}
+import org.apache.spark.internal.{Logging, MDC, MessageWithContext}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression}
@@ -485,14 +485,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
}
object CacheManager extends Logging {
- def logCacheOperation(f: => LogEntry): Unit = {
- SQLConf.get.dataframeCacheLogLevel match {
- case "TRACE" => logTrace(f)
- case "DEBUG" => logDebug(f)
- case "INFO" => logInfo(f)
- case "WARN" => logWarning(f)
- case "ERROR" => logError(f)
- case _ => logTrace(f)
- }
+ def logCacheOperation(f: => MessageWithContext): Unit = {
+ logBasedOnLevel(SQLConf.get.dataframeCacheLogLevel)(f)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index a67648f24b4c2..4c9ae155ec17a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -533,6 +533,10 @@ case class ApplyColumnarRulesAndInsertTransitions(
case write: DataWritingCommandExec
if write.cmd.isInstanceOf[V1WriteCommand] && conf.plannedWriteEnabled =>
write.child.supportsColumnar
+ // If columnar output is not required (`outputsColumnar` is false) but the plan supports
+ // both row-based and columnar output, its children do not need to produce row-based data,
+ // so we request columnar output from them by returning true.
+ case _ if plan.supportsColumnar && plan.supportsRowBased => true
case _ =>
false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
index 0594ad2676e05..2e878c21dc7a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
@@ -297,14 +297,8 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
/**
* Generate detailed field string with different format based on type of input value
*/
- // TODO(nemanja.petrovic@databricks.com) Delete method as it is duplicated in QueryPlan.scala.
- def generateFieldString(fieldName: String, values: Any): String = values match {
- case iter: Iterable[_] if (iter.size == 0) => s"${fieldName}: []"
- case iter: Iterable[_] => s"${fieldName} [${iter.size}]: ${iter.mkString("[", ", ", "]")}"
- case str: String if (str == null || str.isEmpty) => s"${fieldName}: None"
- case str: String => s"${fieldName}: ${str}"
- case _ => throw new IllegalArgumentException(s"Unsupported type for argument values: $values")
- }
+ def generateFieldString(fieldName: String, values: Any): String =
+ QueryPlan.generateFieldString(fieldName, values)
/**
* Given a input plan, returns an array of tuples comprising of :
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
index 28c2ec4b5b7a5..21cf70dab59f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
@@ -23,12 +23,13 @@ import java.time._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.ToStringBase
-import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, FractionTimeFormatter, TimeFormatter, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.HIVE_STYLE
import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros, periodToMonths, toDayTimeIntervalString, toYearMonthIntervalString}
import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedCommandExec, ShowTablesCommand, ShowViewsCommand}
import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTablesExec}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
import org.apache.spark.util.ArrayImplicits._
@@ -37,13 +38,14 @@ import org.apache.spark.util.ArrayImplicits._
* Runs a query returning the result in Hive compatible form.
*/
object HiveResult extends SQLConfHelper {
- case class TimeFormatters(date: DateFormatter, timestamp: TimestampFormatter)
+ case class TimeFormatters(date: DateFormatter, time: TimeFormatter, timestamp: TimestampFormatter)
def getTimeFormatters: TimeFormatters = {
val dateFormatter = DateFormatter()
+ val timeFormatter = new FractionTimeFormatter()
val timestampFormatter = TimestampFormatter.getFractionFormatter(
DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
- TimeFormatters(dateFormatter, timestampFormatter)
+ TimeFormatters(dateFormatter, timeFormatter, timestampFormatter)
}
type BinaryFormatter = Array[Byte] => String
@@ -51,7 +53,7 @@ object HiveResult extends SQLConfHelper {
def getBinaryFormatter: BinaryFormatter = {
if (conf.getConf(SQLConf.BINARY_OUTPUT_STYLE).isEmpty) {
// Keep the legacy behavior for compatibility.
- conf.setConf(SQLConf.BINARY_OUTPUT_STYLE, Some("UTF-8"))
+ conf.setConf(SQLConf.BINARY_OUTPUT_STYLE, Some(BinaryOutputStyle.UTF8))
}
ToStringBase.getBinaryFormatter(_).toString
}
@@ -113,6 +115,7 @@ object HiveResult extends SQLConfHelper {
case (b, BooleanType) => b.toString
case (d: Date, DateType) => formatters.date.format(d)
case (ld: LocalDate, DateType) => formatters.date.format(ld)
+ case (lt: LocalTime, _: TimeType) => formatters.time.format(lt)
case (t: Timestamp, TimestampType) => formatters.timestamp.format(t)
case (i: Instant, TimestampType) => formatters.timestamp.format(i)
case (l: LocalDateTime, TimestampNTZType) => formatters.timestamp.format(l)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 87cafa58d5fa6..a43c1cc0177db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -209,7 +209,7 @@ class QueryExecution(
executePhase(QueryPlanningTracker.PLANNING) {
// Clone the logical plan here, in case the planner rules change the states of the logical
// plan.
- QueryExecution.createSparkPlan(sparkSession, planner, optimizedPlan.clone())
+ QueryExecution.createSparkPlan(planner, optimizedPlan.clone())
}
}
@@ -574,7 +574,6 @@ object QueryExecution {
* Note that the returned physical plan still needs to be prepared for execution.
*/
def createSparkPlan(
- sparkSession: SparkSession,
planner: SparkPlanner,
plan: LogicalPlan): SparkPlan = {
// TODO: We use next(), i.e. take the first plan returned by the planner, here for now,
@@ -594,7 +593,7 @@ object QueryExecution {
* [[SparkPlan]] for execution.
*/
def prepareExecutedPlan(spark: SparkSession, plan: LogicalPlan): SparkPlan = {
- val sparkPlan = createSparkPlan(spark, spark.sessionState.planner, plan.clone())
+ val sparkPlan = createSparkPlan(spark.sessionState.planner, plan.clone())
prepareExecutedPlan(spark, sparkPlan)
}
@@ -603,11 +602,11 @@ object QueryExecution {
* This method is only called by [[PlanAdaptiveDynamicPruningFilters]].
*/
def prepareExecutedPlan(
- session: SparkSession,
plan: LogicalPlan,
context: AdaptiveExecutionContext): SparkPlan = {
- val sparkPlan = createSparkPlan(session, session.sessionState.planner, plan.clone())
- val preparationRules = preparations(session, Option(InsertAdaptiveSparkPlan(context)), true)
+ val sparkPlan = createSparkPlan(context.session.sessionState.planner, plan.clone())
+ val preparationRules =
+ preparations(context.session, Option(InsertAdaptiveSparkPlan(context)), true)
prepareForExecution(preparationRules, sparkPlan.clone())
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index a51870cfd7fdd..60bde20fe235c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -87,7 +87,7 @@ class SparkOptimizer(
ColumnPruning,
LimitPushDown,
PushPredicateThroughNonJoin,
- PushProjectionThroughLimit,
+ PushProjectionThroughLimitAndOffset,
RemoveNoopOperators),
Batch("Infer window group limit", Once,
InferWindowGroupLimit,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 615c8746a3e52..4410fe50912f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{EmptyRelation, LogicalPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.adaptive.LogicalQueryStage
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@@ -56,6 +56,7 @@ private[execution] object SparkPlanInfo {
private def fromLogicalPlan(plan: LogicalPlan): SparkPlanInfo = {
val childrenInfo = plan match {
case LogicalQueryStage(_, physical) => Seq(fromSparkPlan(physical))
+ case EmptyRelation(logical) => Seq(fromLogicalPlan(logical))
case _ => (plan.children ++ plan.subqueries).map(fromLogicalPlan)
}
new SparkPlanInfo(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 8859b7b421b3c..e797ba0392bee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -27,7 +27,7 @@ import org.antlr.v4.runtime.tree.TerminalNode
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, PersistedView, PlanWithUnresolvedIdentifier, SchemaEvolution, SchemaTypeEvolution, UnresolvedAttribute, UnresolvedFunctionName, UnresolvedIdentifier, UnresolvedNamespace}
+import org.apache.spark.sql.catalyst.analysis.{CurrentNamespace, GlobalTempView, LocalTempView, PersistedView, PlanWithUnresolvedIdentifier, SchemaEvolution, SchemaTypeEvolution, UnresolvedAttribute, UnresolvedFunctionName, UnresolvedIdentifier, UnresolvedNamespace}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.parser._
@@ -365,7 +365,7 @@ class SparkSqlAstBuilder extends AstBuilder {
visitCreateTableClauses(ctx.createTableClauses())
val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText).getOrElse(
throw QueryParsingErrors.createTempTableNotSpecifyProviderError(ctx))
- val schema = Option(ctx.colDefinitionList()).map(createSchema)
+ val schema = Option(ctx.tableElementList()).map(createSchema)
logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " +
"CREATE TEMPORARY VIEW ... USING ... instead")
@@ -689,6 +689,20 @@ class SparkSqlAstBuilder extends AstBuilder {
throw QueryParsingErrors.createFuncWithBothIfNotExistsAndReplaceError(ctx)
}
+ // Reject invalid options
+ for {
+ parameters <- Option(ctx.parameters)
+ colDefinition <- parameters.colDefinition().asScala
+ option <- colDefinition.colDefinitionOption().asScala
+ } {
+ if (option.generationExpression() != null) {
+ throw QueryParsingErrors.createFuncWithGeneratedColumnsError(ctx.parameters)
+ }
+ if (option.columnConstraintDefinition() != null) {
+ throw QueryParsingErrors.createFuncWithConstraintError(ctx.parameters)
+ }
+ }
+
val inputParamText = Option(ctx.parameters).map(source)
val returnTypeText: String =
if (ctx.RETURNS != null &&
@@ -1179,4 +1193,26 @@ class SparkSqlAstBuilder extends AstBuilder {
}
}
}
+
+ override def visitShowProcedures(ctx: ShowProceduresContext): LogicalPlan = withOrigin(ctx) {
+ val ns = if (ctx.identifierReference != null) {
+ withIdentClause(ctx.identifierReference, UnresolvedNamespace(_))
+ } else {
+ CurrentNamespace
+ }
+ ShowProceduresCommand(ns)
+ }
+
+ override def visitShowNamespaces(ctx: ShowNamespacesContext): LogicalPlan = withOrigin(ctx) {
+ val multiPart = Option(ctx.multipartIdentifier).map(visitMultipartIdentifier)
+ ShowNamespacesCommand(
+ UnresolvedNamespace(multiPart.getOrElse(Seq.empty[String])),
+ Option(ctx.pattern).map(x => string(visitStringLit(x))))
+ }
+
+ override def visitDescribeProcedure(
+ ctx: DescribeProcedureContext): LogicalPlan = withOrigin(ctx) {
+ withIdentClause(ctx.identifierReference(), procIdentifier =>
+ DescribeProcedureCommand(UnresolvedIdentifier(procIdentifier)))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 89f86c347568d..5be6998f08d27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.{WriteFiles, WriteFilesExec}
import org.apache.spark.sql.execution.exchange.{REBALANCE_PARTITIONS_BY_COL, REBALANCE_PARTITIONS_BY_NONE, REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec}
import org.apache.spark.sql.execution.python._
-import org.apache.spark.sql.execution.python.streaming.{FlatMapGroupsInPandasWithStateExec, TransformWithStateInPandasExec}
+import org.apache.spark.sql.execution.python.streaming.{FlatMapGroupsInPandasWithStateExec, TransformWithStateInPySparkExec}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
import org.apache.spark.sql.internal.SQLConf
@@ -792,20 +792,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
- * Strategy to convert [[TransformWithStateInPandas]] logical operator to physical operator
+ * Strategy to convert [[TransformWithStateInPySpark]] logical operator to physical operator
* in streaming plans.
*/
- object TransformWithStateInPandasStrategy extends Strategy {
+ object TransformWithStateInPySparkStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case t @ TransformWithStateInPandas(
- func, _, outputAttrs, outputMode, timeMode, child,
+ case t @ TransformWithStateInPySpark(
+ func, _, outputAttrs, outputMode, timeMode, userFacingDataType, child,
hasInitialState, initialState, _, initialStateSchema) =>
- val execPlan = TransformWithStateInPandasExec(
+ val execPlan = TransformWithStateInPySparkExec(
func, t.leftAttributes, outputAttrs, outputMode, timeMode,
stateInfo = None,
batchTimestampMs = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
+ userFacingDataType,
planLater(child),
isStreaming = true,
hasInitialState,
@@ -976,12 +977,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
keyEncoder, outputObjAttr, planLater(child), hasInitialState,
initialStateGroupingAttrs, initialStateDataAttrs,
initialStateDeserializer, planLater(initialState)) :: Nil
- case t @ TransformWithStateInPandas(
- func, _, outputAttrs, outputMode, timeMode, child,
+ case t @ TransformWithStateInPySpark(
+ func, _, outputAttrs, outputMode, timeMode, userFacingDataType, child,
hasInitialState, initialState, _, initialStateSchema) =>
- TransformWithStateInPandasExec.generateSparkPlanForBatchQueries(func,
- t.leftAttributes, outputAttrs, outputMode, timeMode, planLater(child), hasInitialState,
- planLater(initialState), t.rightAttributes, initialStateSchema) :: Nil
+ TransformWithStateInPySparkExec.generateSparkPlanForBatchQueries(func,
+ t.leftAttributes, outputAttrs, outputMode, timeMode, userFacingDataType,
+ planLater(child), hasInitialState, planLater(initialState), t.rightAttributes,
+ initialStateSchema) :: Nil
case _: FlatMapGroupsInPandasWithState =>
// TODO(SPARK-40443): support applyInPandasWithState in batch query
@@ -1031,6 +1033,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
GlobalLimitExec(child = planLater(child), offset = offset) :: Nil
case union: logical.Union =>
execution.UnionExec(union.children.map(planLater)) :: Nil
+ case u @ logical.UnionLoop(id, anchor, recursion, limit) =>
+ execution.UnionLoopExec(id, anchor, recursion, u.output, limit) :: Nil
case g @ logical.Generate(generator, _, outer, _, _, child) =>
execution.GenerateExec(
generator, g.requiredChildOutput, outer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
new file mode 100644
index 0000000000000..85c7a57467b5d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, LogicalPlan, Union, UnionLoopRef}
+import org.apache.spark.sql.classic.Dataset
+import org.apache.spark.sql.execution.LogicalRDD.rewriteStatsAndConstraints
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
+
+
+/**
+ * The physical node for recursion. Currently only the UNION ALL case is supported.
+ * For details about the execution, see the comment above the doExecute function.
+ *
+ * A simple recursive query:
+ * {{{
+ * WITH RECURSIVE t(n) AS (
+ * SELECT 1
+ * UNION ALL
+ * SELECT n+1 FROM t WHERE n < 5)
+ * SELECT * FROM t;
+ * }}}
+ * Corresponding logical plan for the recursive query above:
+ * {{{
+ * WithCTE
+ * :- CTERelationDef 0, false
+ * : +- SubqueryAlias t
+ * : +- Project [1#0 AS n#3]
+ * : +- UnionLoop 0
+ * : :- Project [1 AS 1#0]
+ * : : +- OneRowRelation
+ * : +- Project [(n#1 + 1) AS (n + 1)#2]
+ * : +- Filter (n#1 < 5)
+ * : +- SubqueryAlias t
+ * : +- Project [1#0 AS n#1]
+ * : +- UnionLoopRef 0, [1#0], false
+ * +- Project [n#3]
+ * +- SubqueryAlias t
+ * +- CTERelationRef 0, true, [n#3], false, false
+ * }}}
+ *
+ * @param loopId The id of the CTERelationDef containing the recursive query. Its value is
+ * first passed down to UnionLoop when creating it, and then to UnionLoopExec in
+ * SparkStrategies.
+ * @param anchor The logical plan of the initial element of the loop.
+ * @param recursion The logical plan that describes the recursion with a [[UnionLoopRef]] node.
+ * The CTERelationRef that is marked as recursive gets substituted with
+ * [[UnionLoopRef]] in ResolveWithCTE.
+ * Both anchor and recursion are marked with the @transient annotation so that they
+ * are not serialized.
+ * @param output The output attributes of this loop.
+ * @param limit If defined, the total number of rows output by this operator will be bounded by
+ * limit.
+ * Its value is pushed down to UnionLoop in the Optimizer when a LocalLimit node is
+ * present in the logical plan, and then transferred to UnionLoopExec in
+ * SparkStrategies.
+ * Note that the limit can only be applied in the main query calling the recursive CTE,
+ * not inside the recursive term of the recursive CTE.
+ */
+case class UnionLoopExec(
+ loopId: Long,
+ @transient anchor: LogicalPlan,
+ @transient recursion: LogicalPlan,
+ override val output: Seq[Attribute],
+ limit: Option[Int] = None) extends LeafExecNode {
+
+ override def innerChildren: Seq[QueryPlan[_]] = Seq(anchor, recursion)
+
+ override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+ "numIterations" -> SQLMetrics.createMetric(sparkContext, "number of recursive iterations"),
+ "numAnchorOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of anchor output rows"))
+
+ /**
+ * Executes the plan (optionally with a limit node appended) and caches the result,
+ * using the caching mode specified in the config.
+ */
+ private def executeAndCacheAndCount(plan: LogicalPlan, currentLimit: Int) = {
+ // In case limit is defined, we create a (local) limit node above the plan and execute
+ // the newly created plan.
+ val planWithLimit = if (limit.isDefined) {
+ LocalLimit(Literal(currentLimit), plan)
+ } else {
+ plan
+ }
+ val df = Dataset.ofRows(session, planWithLimit)
+ val materializedDF = df.repartition()
+ val count = materializedDF.queryExecution.toRdd.count()
+ (materializedDF, count)
+ }
+
+ /**
+ * In the first iteration, the anchor term is executed.
+ * Then, in each following iteration, the UnionLoopRef node is substituted with the plan from the
+ * previous iteration, and that plan is executed.
+ * After every iteration, the dataframe is materialized.
+ * The recursion stops when the generated dataframe is empty, or when either the limit or
+ * the maximum recursion depth from the config is reached.
+ */
+ override protected def doExecute(): RDD[InternalRow] = {
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ val numOutputRows = longMetric("numOutputRows")
+ val numIterations = longMetric("numIterations")
+ val numAnchorOutputRows = longMetric("numAnchorOutputRows")
+ val levelLimit = conf.getConf(SQLConf.CTE_RECURSION_LEVEL_LIMIT)
+ val rowLimit = conf.getConf(SQLConf.CTE_RECURSION_ROW_LIMIT)
+
+ // currentLimit is initialized from the limit argument, and in each step it is decreased by
+ // the number of rows generated in that step.
+ // If limit is not passed down, currentLimit is set to -1 and is effectively ignored, since
+ // limitReached is only updated when limit.isDefined.
+ var currentLimit = limit.getOrElse(-1)
+
+ val unionChildren = mutable.ArrayBuffer.empty[LogicalRDD]
+
+ var (prevDF, prevCount) = executeAndCacheAndCount(anchor, currentLimit)
+
+ numAnchorOutputRows += prevCount
+
+ var currentLevel = 1
+
+ var currentNumRows = 0
+
+ var limitReached: Boolean = false
+
+ val numPartitions = prevDF.queryExecution.toRdd.partitions.length
+ // Main loop for obtaining the result of the recursive query.
+ while (prevCount > 0 && !limitReached) {
+
+ if (levelLimit != -1 && currentLevel > levelLimit) {
+ throw new SparkException(
+ errorClass = "RECURSION_LEVEL_LIMIT_EXCEEDED",
+ messageParameters = Map("levelLimit" -> levelLimit.toString),
+ cause = null)
+ }
+
+ // Inherit stats and constraints from the dataset of the previous iteration.
+ val prevPlan = LogicalRDD.fromDataset(prevDF.queryExecution.toRdd, prevDF, prevDF.isStreaming)
+ .newInstance()
+ unionChildren += prevPlan
+
+ currentNumRows += prevCount.toInt
+
+ if (limit.isDefined) {
+ currentLimit -= prevCount.toInt
+ if (currentLimit <= 0) {
+ limitReached = true
+ }
+ }
+
+ if (rowLimit != -1 && currentNumRows > rowLimit) {
+ throw new SparkException(
+ errorClass = "RECURSION_ROW_LIMIT_EXCEEDED",
+ messageParameters = Map("rowLimit" -> rowLimit.toString),
+ cause = null)
+ }
+
+ // Update metrics
+ numOutputRows += prevCount
+ numIterations += 1
+
+ if (!limitReached) {
+ // The current plan is created by substituting the UnionLoopRef node with the materialized
+ // plan of the previous iteration.
+ // This supports only the UNION ALL case; an additional case would be needed for UNION.
+ // One way of supporting UNION can be seen in the SPARK-24497 PR from Peter Toth.
+ val newRecursion = recursion.transform {
+ case r: UnionLoopRef if r.loopId == loopId =>
+ val logicalPlan = prevDF.logicalPlan
+ val optimizedPlan = prevDF.queryExecution.optimizedPlan
+ val (stats, constraints) = rewriteStatsAndConstraints(logicalPlan, optimizedPlan)
+ prevPlan.copy(output = r.output)(prevDF.sparkSession, stats, constraints)
+ }
+
+ val (df, count) = executeAndCacheAndCount(newRecursion, currentLimit)
+ prevDF = df
+ prevCount = count
+
+ currentLevel += 1
+ }
+ }
+
+ SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+
+ if (unionChildren.isEmpty) {
+ new EmptyRDD[InternalRow](sparkContext)
+ } else {
+ val df = {
+ if (unionChildren.length == 1) {
+ Dataset.ofRows(session, unionChildren.head)
+ } else {
+ Dataset.ofRows(session, Union(unionChildren.toSeq))
+ }
+ }
+ val coalescedDF = df.coalesce(numPartitions)
+ coalescedDF.queryExecution.toRdd
+ }
+ }
+
+ override def doCanonicalize(): SparkPlan =
+ super.doCanonicalize().asInstanceOf[UnionLoopExec]
+ .copy(anchor = anchor.canonicalized, recursion = recursion.canonicalized)
+
+ override def verboseStringWithOperatorId(): String = {
+ s"""
+ |$formattedNodeName
+ |Loop id: $loopId
+ |${QueryPlan.generateFieldString("Output", output)}
+ |Limit: $limit
+ |""".stripMargin
+ }
+}
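
For intuition, the iteration implemented by doExecute can be modelled outside Spark as a small loop over plain collections. The sketch below is illustrative only (names such as `unionLoop` and `step` are made up); it approximates the UNION ALL semantics: emit the anchor rows, then repeatedly apply the recursive step to the previous iteration's rows until an iteration is empty, the level limit is hit, or the optional row limit is exhausted.

// Toy model of the UnionLoopExec iteration scheme (not Spark code).
object UnionLoopSketch {
  def unionLoop[Row](
      anchor: Seq[Row],
      step: Seq[Row] => Seq[Row],     // recursion with UnionLoopRef replaced by the previous rows
      limit: Option[Int] = None,
      levelLimit: Int = 100): Seq[Row] = {
    val out = scala.collection.mutable.ArrayBuffer.empty[Row]
    var prev = anchor
    var level = 1
    while (prev.nonEmpty && limit.forall(_ > out.size)) {
      require(level <= levelLimit, s"recursion level limit $levelLimit exceeded")
      out ++= prev                    // materialize this iteration's rows
      prev = step(prev)               // compute the next iteration
      level += 1
    }
    limit.fold(out.toSeq)(n => out.take(n).toSeq)
  }

  def main(args: Array[String]): Unit = {
    // WITH RECURSIVE t(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM t WHERE n < 5) SELECT * FROM t
    println(unionLoop[Int](Seq(1), prev => prev.filter(_ < 5).map(_ + 1)))  // rows 1, 2, 3, 4, 5
  }
}
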
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 920f61574770d..1ee467ef3554b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -974,8 +974,7 @@ case class CollapseCodegenStages(
}
def apply(plan: SparkPlan): SparkPlan = {
- if (conf.wholeStageEnabled && CodegenObjectFactoryMode.withName(conf.codegenFactoryMode)
- != CodegenObjectFactoryMode.NO_CODEGEN) {
+ if (conf.wholeStageEnabled && conf.codegenFactoryMode != CodegenObjectFactoryMode.NO_CODEGEN) {
insertWholeStageCodegen(plan)
} else {
plan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
index 12e8d0e2c6089..e8b70f94a7692 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
@@ -200,7 +200,15 @@ case class AQEShuffleReadExec private(
val numCoalescedPartitionsMetric = metrics("numCoalescedPartitions")
val x = partitionSpecs.count(isCoalescedSpec)
numCoalescedPartitionsMetric.set(x)
- driverAccumUpdates += numCoalescedPartitionsMetric.id -> x
+ val numEmptyPartitionsMetric = metrics("numEmptyPartitions")
+ val y = child match {
+ case s: ShuffleQueryStageExec =>
+ s.mapStats.map(stats => stats.bytesByPartitionId.count(_ == 0)).getOrElse(0)
+ case _ => 0
+ }
+ numEmptyPartitionsMetric.set(y)
+ driverAccumUpdates ++= Seq(numCoalescedPartitionsMetric.id -> x,
+ numEmptyPartitionsMetric.id -> y)
}
partitionDataSizes.foreach { dataSizes =>
@@ -236,7 +244,9 @@ case class AQEShuffleReadExec private(
} ++ {
if (hasCoalescedPartition) {
Map("numCoalescedPartitions" ->
- SQLMetrics.createMetric(sparkContext, "number of coalesced partitions"))
+ SQLMetrics.createMetric(sparkContext, "number of coalesced partitions"),
+ "numEmptyPartitions" ->
+ SQLMetrics.createMetric(sparkContext, "number of empty partitions"))
} else {
Map.empty
}
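
The new `numEmptyPartitions` metric is derived from the shuffle's map output statistics: a reduce partition is counted as empty when the bytes written for it across all map tasks sum to zero. A minimal sketch of that computation over a plain `Array[Long]` (standing in for `MapOutputStatistics.bytesByPartitionId`):

// Count shuffle partitions that received no data, as the new AQE metric does.
def countEmptyPartitions(bytesByPartitionId: Array[Long]): Int =
  bytesByPartitionId.count(_ == 0)

// Example: partitions 1 and 3 produced no bytes.
val bytes = Array(128L, 0L, 42L, 0L)
assert(countEmptyPartitions(bytes) == 2)
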
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 07d215f8a186f..996e01a0ea936 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -77,13 +77,7 @@ case class AdaptiveSparkPlanExec(
@transient private val lock = new Object()
@transient private val logOnLevel: ( => MessageWithContext) => Unit =
- conf.adaptiveExecutionLogLevel match {
- case "TRACE" => logTrace(_)
- case "INFO" => logInfo(_)
- case "WARN" => logWarning(_)
- case "ERROR" => logError(_)
- case _ => logDebug(_)
- }
+ logBasedOnLevel(conf.adaptiveExecutionLogLevel)
@transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index 73fc9b1fe4e2c..2855f902a8509 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -153,7 +153,7 @@ case class InsertAdaptiveSparkPlan(
// Apply the same instance of this rule to sub-queries so that sub-queries all share the
// same `stageCache` for Exchange reuse.
this.applyInternal(
- QueryExecution.createSparkPlan(adaptiveExecutionContext.session,
+ QueryExecution.createSparkPlan(
adaptiveExecutionContext.session.sessionState.planner, plan.clone()), true)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
index 77c180b18aee0..751cfe5b7bb6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
@@ -74,9 +74,7 @@ case class PlanAdaptiveDynamicPruningFilters(
val aliases = indices.map(idx => Alias(buildKeys(idx), buildKeys(idx).toString)())
val aggregate = Aggregate(aliases, aliases, buildPlan)
- val session = adaptivePlan.context.session
- val sparkPlan = QueryExecution.prepareExecutedPlan(
- session, aggregate, adaptivePlan.context)
+ val sparkPlan = QueryExecution.prepareExecutedPlan(aggregate, adaptivePlan.context)
assert(sparkPlan.isInstanceOf[AdaptiveSparkPlanExec])
val newAdaptivePlan = sparkPlan.asInstanceOf[AdaptiveSparkPlanExec]
val values = SubqueryExec(name, newAdaptivePlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 367d4cfafb485..de1b83c16ac97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -94,7 +94,8 @@ object AggUtils {
child = child)
} else {
val objectHashEnabled = child.conf.useObjectHashAggregation
- val useObjectHash = Aggregate.supportsObjectHashAggregate(aggregateExpressions)
+ val useObjectHash = Aggregate.supportsObjectHashAggregate(
+ aggregateExpressions, groupingExpressions)
if (forceObjHashAggregate || (objectHashEnabled && useObjectHash && !forceSortAggregate)) {
ObjectHashAggregateExec(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 469f42dcc0afe..24528b6f4da15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -682,7 +682,7 @@ case class HashAggregateExec(
| $unsafeRowKeys, $unsafeRowKeyHash);
| if ($unsafeRowBuffer == null) {
| // failed to allocate the first page
- | throw new $oomeClassName("_LEGACY_ERROR_TEMP_3302", new java.util.HashMap());
+ | throw new $oomeClassName("AGGREGATE_OUT_OF_MEMORY", new java.util.HashMap());
| }
|}
""".stripMargin
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 2f1cda9d0f9be..073e5929025b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -212,7 +212,7 @@ class TungstenAggregationIterator(
if (buffer == null) {
// failed to allocate the first page
// scalastyle:off throwerror
- throw new SparkOutOfMemoryError("_LEGACY_ERROR_TEMP_3302", new util.HashMap())
+ throw new SparkOutOfMemoryError("AGGREGATE_OUT_OF_MEMORY", new util.HashMap())
// scalastyle:on throwerror
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 0f280d236203f..8f704cec7e892 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.columnar
+import com.esotericsoftware.kryo.{DefaultSerializer, Kryo, Serializer => KryoSerializer}
+import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import org.apache.commons.lang3.StringUtils
import org.apache.spark.{SparkException, TaskContext}
@@ -30,11 +32,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Sta
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer}
-import org.apache.spark.sql.execution.{ColumnarToRowTransition, InputAdapter, QueryExecution, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector}
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructType, UserDefinedType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{LongAccumulator, Utils}
@@ -47,9 +49,56 @@ import org.apache.spark.util.ArrayImplicits._
* @param buffers The buffers for serialized columns
* @param stats The stat of columns
*/
-case class DefaultCachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
+@DefaultSerializer(classOf[DefaultCachedBatchKryoSerializer])
+case class DefaultCachedBatch(
+ numRows: Int,
+ buffers: Array[Array[Byte]],
+ stats: InternalRow)
extends SimpleMetricsCachedBatch
+class DefaultCachedBatchKryoSerializer extends KryoSerializer[DefaultCachedBatch] {
+ override def write(kryo: Kryo, output: KryoOutput, batch: DefaultCachedBatch): Unit = {
+ output.writeInt(batch.numRows)
+ SparkException.require(batch.buffers != null, "INVALID_KRYO_SERIALIZER_NO_DATA",
+ Map("obj" -> "DefaultCachedBatch.buffers",
+ "serdeOp" -> "serialize",
+ "serdeClass" -> this.getClass.getName))
+ output.writeInt(batch.buffers.length + 1) // +1 to distinguish Kryo.NULL
+ for (i <- batch.buffers.indices) {
+ val buffer = batch.buffers(i)
+ SparkException.require(buffer != null, "INVALID_KRYO_SERIALIZER_NO_DATA",
+ Map("obj" -> s"DefaultCachedBatch.buffers($i)",
+ "serdeOp" -> "serialize",
+ "serdeClass" -> this.getClass.getName))
+ output.writeInt(buffer.length + 1) // +1 to distinguish Kryo.NULL
+ output.writeBytes(buffer)
+ }
+ kryo.writeClassAndObject(output, batch.stats)
+ }
+
+ override def read(
+ kryo: Kryo, input: KryoInput, cls: Class[DefaultCachedBatch]): DefaultCachedBatch = {
+ val numRows = input.readInt()
+ val length = input.readInt()
+ SparkException.require(length != Kryo.NULL, "INVALID_KRYO_SERIALIZER_NO_DATA",
+ Map("obj" -> "DefaultCachedBatch.buffers",
+ "serdeOp" -> "deserialize",
+ "serdeClass" -> this.getClass.getName))
+ val buffers = 0.until(length - 1).map { i => // -1 to restore
+ val subLength = input.readInt()
+ SparkException.require(subLength != Kryo.NULL, "INVALID_KRYO_SERIALIZER_NO_DATA",
+ Map("obj" -> s"DefaultCachedBatch.buffers($i)",
+ "serdeOp" -> "deserialize",
+ "serdeClass" -> this.getClass.getName))
+ val innerArray = new Array[Byte](subLength - 1) // -1 to restore
+ input.readBytes(innerArray)
+ innerArray
+ }.toArray
+ val stats = kryo.readClassAndObject(input).asInstanceOf[InternalRow]
+ DefaultCachedBatch(numRows, buffers, stats)
+ }
+}
+
/**
* The default implementation of CachedBatchSerializer.
*/
@@ -467,4 +516,7 @@ case class InMemoryRelation(
override def simpleString(maxFields: Int): String =
s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}"
+
+ override def stringArgs: Iterator[Any] =
+ Iterator(output, cacheBuilder.storageLevel, outputOrdering)
}
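
The custom serializer stores every length as `length + 1` so that a read value equal to `Kryo.NULL` (0) can only mean missing data, which is then reported as an `INVALID_KRYO_SERIALIZER_NO_DATA` error rather than silently yielding a truncated batch. The following sketch reproduces the same offset-by-one framing with plain Java data streams; it is illustrative only and does not go through Kryo's `Output`/`Input`.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Write nested byte arrays with every length stored as (length + 1).
def writeBuffers(buffers: Array[Array[Byte]]): Array[Byte] = {
  val bytes = new ByteArrayOutputStream()
  val out = new DataOutputStream(bytes)
  out.writeInt(buffers.length + 1)
  buffers.foreach { b =>
    out.writeInt(b.length + 1)
    out.write(b)
  }
  out.flush()
  bytes.toByteArray
}

def readBuffers(data: Array[Byte]): Array[Array[Byte]] = {
  val in = new DataInputStream(new ByteArrayInputStream(data))
  val outer = in.readInt()
  require(outer != 0, "corrupt stream: missing buffers")   // 0 plays the role of Kryo.NULL
  Array.fill(outer - 1) {
    val inner = in.readInt()
    require(inner != 0, "corrupt stream: missing buffer")
    val arr = new Array[Byte](inner - 1)
    in.readFully(arr)
    arr
  }
}

// Round trip, including an empty inner buffer.
val original = Array(Array[Byte](1, 2, 3), Array.empty[Byte])
assert(readBuffers(writeBuffers(original)).map(_.toSeq).toSeq == original.map(_.toSeq).toSeq)
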
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
index fe4e6f121f57b..09b2c86970754 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{Analyzer, SQLFunctionNode, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation}
-import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, SQLFunction, UserDefinedFunctionErrors}
+import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, SQLFunction, UserDefinedFunction, UserDefinedFunctionErrors}
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Generator, LateralSubquery, Literal, ScalarSubquery, SubqueryExpression, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.Inner
@@ -70,7 +70,7 @@ case class CreateSQLFunctionCommand(
val catalog = sparkSession.sessionState.catalog
val conf = sparkSession.sessionState.conf
- val inputParam = inputParamText.map(parser.parseTableSchema)
+ val inputParam = inputParamText.map(UserDefinedFunction.parseRoutineParam(_, parser))
val returnType = parseReturnTypeText(returnTypeText, isTableFunc, parser)
val function = SQLFunction(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeProcedureCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeProcedureCommand.scala
new file mode 100644
index 0000000000000..ef7a538307bf0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeProcedureCommand.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkException, SparkThrowable}
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.analysis.ResolvedIdentifier
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.connector.catalog.{Identifier, ProcedureCatalog}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+import org.apache.spark.sql.connector.catalog.procedures.UnboundProcedure
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.StringType
+
+/**
+ * A command for users to describe a procedure.
+ * The syntax of using this command in SQL is:
+ * {{{
+ * DESC PROCEDURE procedure_name
+ * }}}
+ */
+case class DescribeProcedureCommand(
+ child: LogicalPlan,
+ override val output: Seq[Attribute] = Seq(
+ AttributeReference("procedure_desc", StringType, nullable = false)()
+ )) extends UnaryRunnableCommand {
+
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ child match {
+ case ResolvedIdentifier(catalog, ident) =>
+ val procedure = load(catalog.asProcedureCatalog, ident)
+ describeV2Procedure(procedure)
+ case _ =>
+ throw SparkException.internalError(s"Invalid procedure identifier: ${child.getClass}")
+ }
+ }
+
+ private def load(catalog: ProcedureCatalog, ident: Identifier): UnboundProcedure = {
+ try {
+ catalog.loadProcedure(ident)
+ } catch {
+ case e: Exception if !e.isInstanceOf[SparkThrowable] =>
+ val nameParts = catalog.name +: ident.asMultipartIdentifier
+ throw QueryCompilationErrors.failedToLoadRoutineError(nameParts, e)
+ }
+ }
+
+ private def describeV2Procedure(procedure: UnboundProcedure): Seq[Row] = {
+ val buffer = new ArrayBuffer[(String, String)]
+ append(buffer, "Procedure:", procedure.name())
+ append(buffer, "Description:", procedure.description())
+
+ val keys = tabulate(buffer.map(_._1).toSeq)
+ val values = buffer.map(_._2)
+ keys.zip(values).map { case (key, value) => Row(s"$key $value") }
+ }
+
+ private def append(buffer: ArrayBuffer[(String, String)], key: String, value: String): Unit = {
+ buffer += (key -> value)
+ }
+
+ /**
+ * Pad all input strings to the same length, using the maximum string length among the inputs.
+ */
+ private def tabulate(inputs: Seq[String]): Seq[String] = {
+ val maxLen = inputs.map(_.length).max
+ inputs.map { input => input.padTo(maxLen, " ").mkString }
+ }
+
+ override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = {
+ copy(child = newChild)
+ }
+}
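
The describe output is assembled by right-padding every key to the width of the longest key and then joining each key with its value. A small stand-alone sketch of that `tabulate` step (the procedure name and description below are made up):

// Right-pad keys to a common width, as DescribeProcedureCommand.tabulate does.
def tabulate(keys: Seq[String]): Seq[String] = {
  val maxLen = keys.map(_.length).max
  keys.map(_.padTo(maxLen, ' '))
}

val keys = tabulate(Seq("Procedure:", "Description:"))
val values = Seq("dummy_increment", "increment rows")
keys.zip(values).foreach { case (k, v) => println(s"$k $v") }
// Procedure:   dummy_increment
// Description: increment rows
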
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala
index 0607a8593fbb8..ed248ccca67a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala
@@ -74,6 +74,7 @@ case class DescribeRelationJsonCommand(
describeIdentifier(v.identifier.toQualifiedNameParts(v.catalog), jsonMap)
describeColsJson(v.metadata.schema, jsonMap)
describeFormattedTableInfoJson(v.metadata, jsonMap)
+ describeViewSqlConfsJson(v.metadata, jsonMap)
case ResolvedTable(catalog, identifier, V1Table(metadata), _) =>
describeIdentifier(identifier.toQualifiedNameParts(catalog), jsonMap)
@@ -98,6 +99,10 @@ case class DescribeRelationJsonCommand(
case _ => throw QueryCompilationErrors.describeAsJsonNotSupportedForV2TablesError()
}
+ // Add default collation if not yet added (addKeyValueToMap only adds unique keys).
+ // Add here to only affect `DESC AS JSON` and not the `DESC TABLE` output.
+ addKeyValueToMap("collation", JString("UTF8_BINARY"), jsonMap)
+
Seq(Row(compact(render(JObject(jsonMap.toList)))))
}
@@ -223,6 +228,12 @@ case class DescribeRelationJsonCommand(
"end_unit" -> JString(getFieldName(dayTimeIntervalType.endField))
)
+ case stringType: StringType =>
+ JObject(
+ "name" -> JString("string"),
+ "collation" -> JString(stringType.collationName)
+ )
+
case _ =>
JObject("name" -> JString(dataType.simpleString))
}
@@ -236,26 +247,31 @@ case class DescribeRelationJsonCommand(
addKeyValueToMap("columns", columnsJson, jsonMap)
}
+ /** Display SQL confs set at the time of view creation. */
+ private def describeViewSqlConfsJson(
+ table: CatalogTable,
+ jsonMap: mutable.LinkedHashMap[String, JValue]): Unit = {
+ val viewConfigs: Map[String, String] = table.viewSQLConfigs
+ val viewConfigsJson: JValue = JObject(viewConfigs.map { case (key, value) =>
+ key -> JString(value)
+ }.toList)
+ addKeyValueToMap("view_creation_spark_configuration", viewConfigsJson, jsonMap)
+ }
+
private def describeClusteringInfoJson(
table: CatalogTable, jsonMap: mutable.LinkedHashMap[String, JValue]): Unit = {
table.clusterBySpec.foreach { clusterBySpec =>
- val clusteringColumnsJson: JValue = JArray(
- clusterBySpec.columnNames.map { fieldNames =>
- val nestedFieldOpt = table.schema.findNestedField(fieldNames.fieldNames.toIndexedSeq)
- assert(nestedFieldOpt.isDefined,
- "The clustering column " +
- s"${fieldNames.fieldNames.map(quoteIfNeeded).mkString(".")} " +
- s"was not found in the table schema ${table.schema.catalogString}."
- )
- val (path, field) = nestedFieldOpt.get
- JObject(
- "name" -> JString((path :+ field.name).map(quoteIfNeeded).mkString(".")),
- "type" -> jsonType(field.dataType),
- "comment" -> field.getComment().map(JString).getOrElse(JNull)
- )
- }.toList
- )
- addKeyValueToMap("clustering_information", clusteringColumnsJson, jsonMap)
+ val clusteringColumnsJson = JArray(clusterBySpec.columnNames.map { fieldNames =>
+ val nestedFieldOpt = table.schema.findNestedField(fieldNames.fieldNames.toIndexedSeq)
+ assert(nestedFieldOpt.isDefined,
+ "The clustering column " +
+ s"${fieldNames.fieldNames.map(quoteIfNeeded).mkString(".")} " +
+ s"was not found in the table schema ${table.schema.catalogString}."
+ )
+ JString(fieldNames.fieldNames.map(quoteIfNeeded).mkString("."))
+ }.toList)
+
+ addKeyValueToMap("clustering_columns", clusteringColumnsJson, jsonMap)
}
}
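
With this change, `DESC ... AS JSON` reports the collation of string columns next to the type name. A hedged json4s sketch of just that fragment of the payload (the collation value is illustrative):

import org.json4s.JsonAST.{JObject, JString}
import org.json4s.jackson.JsonMethods.{compact, render}

// The JSON emitted for a string column's type, per the new `case stringType: StringType` branch.
val stringTypeJson = JObject(
  "name" -> JString("string"),
  "collation" -> JString("UTF8_LCASE")   // illustrative collation name
)
println(compact(render(stringTypeJson)))
// {"name":"string","collation":"UTF8_LCASE"}
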
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ShowNamespacesCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ShowNamespacesCommand.scala
new file mode 100644
index 0000000000000..9814d325ff4ff
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ShowNamespacesCommand.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.analysis.ResolvedNamespace
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.StringUtils
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, NamespaceHelper}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StringType
+
+/**
+ * The command for `SHOW NAMESPACES`.
+ */
+case class ShowNamespacesCommand(
+ child: LogicalPlan,
+ pattern: Option[String],
+ override val output: Seq[Attribute] = ShowNamespacesCommand.output)
+ extends UnaryRunnableCommand {
+
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val ResolvedNamespace(cat, ns, _) = child
+ val nsCatalog = cat.asNamespaceCatalog
+ val namespaces = if (ns.nonEmpty) {
+ nsCatalog.listNamespaces(ns.toArray)
+ } else {
+ nsCatalog.listNamespaces()
+ }
+
+ // The legacy SHOW DATABASES command does not quote the database names.
+ assert(output.length == 1)
+ val namespaceNames = if (output.head.name == "databaseName"
+ && namespaces.forall(_.length == 1)) {
+ namespaces.map(_.head)
+ } else {
+ namespaces.map(_.quoted)
+ }
+
+ namespaceNames
+ .filter{ns => pattern.forall(StringUtils.filterPattern(Seq(ns), _).nonEmpty)}
+ .map(Row(_))
+ .toSeq
+ }
+
+ override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = {
+ copy(child = newChild)
+ }
+}
+
+object ShowNamespacesCommand {
+ def output: Seq[AttributeReference] = {
+ Seq(
+ if (SQLConf.get.legacyOutputSchema) {
+ AttributeReference("databaseName", StringType, nullable = false)()
+ } else {
+ AttributeReference("namespace", StringType, nullable = false)()
+ }
+ )
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ShowProceduresCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ShowProceduresCommand.scala
new file mode 100644
index 0000000000000..f08d0924d7cd2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ShowProceduresCommand.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.analysis.ResolvedNamespace
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+import org.apache.spark.sql.types.{ArrayType, StringType}
+
+/**
+ * A command for users to list procedures.
+ * If a namespace is not given, the current namespace will be used.
+ * The syntax of using this command in SQL is:
+ * {{{
+ * SHOW PROCEDURES [(IN|FROM) namespace]
+ * }}}
+ */
+case class ShowProceduresCommand(
+ child: LogicalPlan,
+ override val output: Seq[Attribute] = Seq(
+ AttributeReference("catalog", StringType, nullable = false)(),
+ AttributeReference("namespace", ArrayType(StringType, containsNull = false))(),
+ AttributeReference("schema", StringType)(),
+ AttributeReference("procedure_name", StringType, nullable = false)()
+ )) extends UnaryRunnableCommand {
+
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ child match {
+ case ResolvedNamespace(catalog, ns, _) =>
+ val procedureCatalog = catalog.asProcedureCatalog
+ val procedures = procedureCatalog.listProcedures(ns.toArray)
+
+ procedures.toSeq.map{ p =>
+ val schema = if (p.namespace() != null && p.namespace().nonEmpty) {
+ p.namespace().last
+ } else {
+ null
+ }
+ Row(catalog.name, p.namespace(), schema, p.name)
+ }
+ }
+ }
+
+ override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = {
+ copy(child = newChild)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index dbf98c70504d8..3a45e655fb1a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -137,7 +137,8 @@ case class CreateViewCommand(
originalText,
analyzedPlan,
aliasedPlan,
- referredTempFunctions)
+ referredTempFunctions,
+ collation)
catalog.createTempView(name.table, tableDefinition, overrideIfExists = replace)
} else if (viewType == GlobalTempView) {
val db = sparkSession.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
@@ -675,7 +676,7 @@ object ViewHelper extends SQLConfHelper with Logging {
val tempVars = collectTemporaryVariables(child)
tempVars.foreach { nameParts =>
throw QueryCompilationErrors.notAllowedToCreatePermanentViewByReferencingTempVarError(
- name, nameParts.quoted)
+ name.nameParts, nameParts)
}
}
}
@@ -739,8 +740,10 @@ object ViewHelper extends SQLConfHelper with Logging {
originalText: Option[String],
analyzedPlan: LogicalPlan,
aliasedPlan: LogicalPlan,
- referredTempFunctions: Seq[String]): TemporaryViewRelation = {
- val uncache = getRawTempView(name.table).map { r =>
+ referredTempFunctions: Seq[String],
+ collation: Option[String] = None): TemporaryViewRelation = {
+ val rawTempView = getRawTempView(name.table)
+ val uncache = rawTempView.map { r =>
needsToUncache(r, aliasedPlan)
}.getOrElse(false)
val storeAnalyzedPlanForView = session.sessionState.conf.storeAnalyzedPlanForView ||
@@ -754,6 +757,16 @@ object ViewHelper extends SQLConfHelper with Logging {
}
CommandUtils.uncacheTableOrView(session, name)
}
+ // When called from CreateViewCommand, this function determines the collation from the
+ // DEFAULT COLLATION clause in the query or assigns None if unspecified.
+ // When called from AlterViewAsCommand, it retrieves the collation from the view's metadata.
+ val defaultCollation = if (collation.isDefined) {
+ collation
+ } else if (rawTempView.isDefined) {
+ rawTempView.get.tableMeta.collation
+ } else {
+ None
+ }
if (!storeAnalyzedPlanForView) {
TemporaryViewRelation(
prepareTemporaryView(
@@ -762,10 +775,11 @@ object ViewHelper extends SQLConfHelper with Logging {
analyzedPlan,
aliasedPlan.schema,
originalText.get,
- referredTempFunctions))
+ referredTempFunctions,
+ defaultCollation))
} else {
TemporaryViewRelation(
- prepareTemporaryViewStoringAnalyzedPlan(name, aliasedPlan),
+ prepareTemporaryViewStoringAnalyzedPlan(name, aliasedPlan, defaultCollation),
Some(aliasedPlan))
}
}
@@ -795,7 +809,8 @@ object ViewHelper extends SQLConfHelper with Logging {
analyzedPlan: LogicalPlan,
viewSchema: StructType,
originalText: String,
- tempFunctions: Seq[String]): CatalogTable = {
+ tempFunctions: Seq[String],
+ collation: Option[String]): CatalogTable = {
val tempViews = collectTemporaryViews(analyzedPlan)
val tempVariables = collectTemporaryVariables(analyzedPlan)
@@ -812,7 +827,8 @@ object ViewHelper extends SQLConfHelper with Logging {
schema = viewSchema,
viewText = Some(originalText),
createVersion = org.apache.spark.SPARK_VERSION,
- properties = newProperties)
+ properties = newProperties,
+ collation = collation)
}
/**
@@ -821,12 +837,14 @@ object ViewHelper extends SQLConfHelper with Logging {
*/
private def prepareTemporaryViewStoringAnalyzedPlan(
viewName: TableIdentifier,
- analyzedPlan: LogicalPlan): CatalogTable = {
+ analyzedPlan: LogicalPlan,
+ collation: Option[String]): CatalogTable = {
CatalogTable(
identifier = viewName,
tableType = CatalogTableType.VIEW,
storage = CatalogStorageFormat.empty,
schema = analyzedPlan.schema,
- properties = Map((VIEW_STORING_ANALYZED_PLAN, "true")))
+ properties = Map((VIEW_STORING_ANALYZED_PLAN, "true")),
+ collation = collation)
}
}
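
The default-collation resolution added to `createTemporaryViewRelation` is a simple fall-through: an explicitly specified collation wins, otherwise the collation stored in the existing temp view's metadata is reused, otherwise there is none. The same logic, sketched over plain Options with a made-up `ViewMeta` holder in place of the catalog table:

// Pure-Scala sketch of the collation fall-through in createTemporaryViewRelation.
case class ViewMeta(collation: Option[String])

def defaultCollation(explicitCollation: Option[String], rawTempView: Option[ViewMeta]): Option[String] =
  explicitCollation.orElse(rawTempView.flatMap(_.collation))

assert(defaultCollation(Some("UNICODE"), Some(ViewMeta(Some("UTF8_BINARY")))) == Some("UNICODE"))
assert(defaultCollation(None, Some(ViewMeta(Some("UTF8_BINARY")))) == Some("UTF8_BINARY"))
assert(defaultCollation(None, None).isEmpty)
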
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 97c88d660b002..882bc12a0d29b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -30,6 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{CLASS_NAME, DATA_SOURCE, DATA_SOURCES, PATHS}
import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
+import org.apache.spark.sql.catalyst.DataSourceOptions
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -261,8 +262,12 @@ case class DataSource(
val isSchemaInferenceEnabled = sparkSession.sessionState.conf.streamingSchemaInference
val isTextSource = providingClass == classOf[text.TextFileFormat]
+ val isSingleVariantColumn = (providingClass == classOf[json.JsonFileFormat] ||
+ providingClass == classOf[csv.CSVFileFormat]) &&
+ caseInsensitiveOptions.contains(DataSourceOptions.SINGLE_VARIANT_COLUMN)
// If the schema inference is disabled, only text sources require schema to be specified
- if (!isSchemaInferenceEnabled && !isTextSource && userSpecifiedSchema.isEmpty) {
+ if (!isSchemaInferenceEnabled && !isTextSource && !isSingleVariantColumn &&
+ userSpecifiedSchema.isEmpty) {
throw QueryExecutionErrors.createStreamingSourceNotSpecifySchemaError()
}
@@ -348,7 +353,7 @@ case class DataSource(
* is considered as a non-streaming file based data source. Since we know
* that files already exist, we don't need to check them again.
*/
- def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = {
+ def resolveRelation(checkFilesExist: Boolean = true, readOnly: Boolean = false): BaseRelation = {
val relation = (providingInstance(), userSpecifiedSchema) match {
// TODO: Throw when too much is given.
case (dataSource: SchemaRelationProvider, Some(schema)) =>
@@ -439,7 +444,7 @@ case class DataSource(
SchemaUtils.checkSchemaColumnNameDuplication(
hs.partitionSchema,
equality)
- DataSourceUtils.verifySchema(hs.fileFormat, hs.dataSchema)
+ DataSourceUtils.verifySchema(hs.fileFormat, hs.dataSchema, readOnly)
case _ =>
SchemaUtils.checkSchemaColumnNameDuplication(
relation.schema,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2f6588c3aac35..d2969bab28d62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -21,6 +21,7 @@ import java.util.Locale
import scala.collection.immutable.ListMap
import scala.collection.mutable
+import scala.jdk.CollectionConverters._
import org.apache.hadoop.fs.Path
@@ -256,20 +257,40 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
QualifiedTableName(table.identifier.catalog.get, table.database, table.identifier.table)
val catalog = sparkSession.sessionState.catalog
val dsOptions = DataSourceUtils.generateDatasourceOptions(extraOptions, table)
- catalog.getCachedPlan(qualifiedTableName, () => {
- val dataSource =
- DataSource(
- sparkSession,
- // In older version(prior to 2.1) of Spark, the table schema can be empty and should be
- // inferred at runtime. We should still support it.
- userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema),
- partitionColumns = table.partitionColumnNames,
- bucketSpec = table.bucketSpec,
- className = table.provider.get,
- options = dsOptions,
- catalogTable = Some(table))
- LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table)
- })
+ val readFileSourceTableCacheIgnoreOptions =
+ SQLConf.get.getConf(SQLConf.READ_FILE_SOURCE_TABLE_CACHE_IGNORE_OPTIONS)
+ catalog.getCachedTable(qualifiedTableName) match {
+ case null =>
+ val dataSource =
+ DataSource(
+ sparkSession,
+ // In older versions of Spark (prior to 2.1), the table schema can be empty and should be
+ // inferred at runtime. We should still support it.
+ userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema),
+ partitionColumns = table.partitionColumnNames,
+ bucketSpec = table.bucketSpec,
+ className = table.provider.get,
+ options = dsOptions,
+ catalogTable = Some(table))
+ val plan = LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table)
+ catalog.cacheTable(qualifiedTableName, plan)
+ plan
+
+ // If readFileSourceTableCacheIgnoreOptions is false AND
+ // the cached table relation's options differ from the new options:
+ // 1. Create a new HadoopFsRelation with updated options
+ // 2. Return a new LogicalRelation with the updated HadoopFsRelation
+ // This ensures the relation reflects any changes in data source options.
+ // Otherwise, leave the cached table relation as is
+ case r @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _, _)
+ if !readFileSourceTableCacheIgnoreOptions &&
+ (new CaseInsensitiveStringMap(fsRelation.options.asJava) !=
+ new CaseInsensitiveStringMap(dsOptions.asJava)) =>
+ val newFsRelation = fsRelation.copy(options = dsOptions)(sparkSession)
+ r.copy(relation = newFsRelation)
+
+ case other => other
+ }
}
private def getStreamingRelation(
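
The cached relation is only rebuilt when the effective data source options actually differ, and the comparison wraps both sides in `CaseInsensitiveStringMap` so that differences in key casing alone do not invalidate the cache. A plain-Scala approximation of that comparison (the real code uses the Spark class):

// Approximate the option comparison in FindDataSourceTable: keys are case-insensitive.
def sameOptions(a: Map[String, String], b: Map[String, String]): Boolean = {
  def normalize(m: Map[String, String]) =
    m.map { case (k, v) => k.toLowerCase(java.util.Locale.ROOT) -> v }
  normalize(a) == normalize(b)
}

assert(sameOptions(Map("PATH" -> "/tmp/t"), Map("path" -> "/tmp/t")))              // only key casing differs
assert(!sameOptions(Map("mergeSchema" -> "true"), Map("mergeSchema" -> "false")))  // genuinely different
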
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index 81eadcc263c61..3e66b97f61a63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -91,9 +91,14 @@ object DataSourceUtils extends PredicateHelper {
* Verify if the schema is supported in datasource. This verification should be done
* in a driver side.
*/
- def verifySchema(format: FileFormat, schema: StructType): Unit = {
+ def verifySchema(format: FileFormat, schema: StructType, readOnly: Boolean = false): Unit = {
schema.foreach { field =>
- if (!format.supportDataType(field.dataType)) {
+ val supported = if (readOnly) {
+ format.supportReadDataType(field.dataType)
+ } else {
+ format.supportDataType(field.dataType)
+ }
+ if (!supported) {
throw QueryCompilationErrors.dataTypeUnsupportedByDataSourceError(format.toString, field)
}
}
@@ -122,7 +127,7 @@ object DataSourceUtils extends PredicateHelper {
private def getRebaseSpec(
lookupFileMeta: String => String,
- modeByConfig: String,
+ modeByConfig: LegacyBehaviorPolicy.Value,
minVersion: String,
metadataKey: String): RebaseSpec = {
val policy = if (Utils.isTesting &&
@@ -140,7 +145,7 @@ object DataSourceUtils extends PredicateHelper {
} else {
LegacyBehaviorPolicy.CORRECTED
}
- }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig))
+ }.getOrElse(modeByConfig)
}
policy match {
case LegacyBehaviorPolicy.LEGACY =>
@@ -151,7 +156,7 @@ object DataSourceUtils extends PredicateHelper {
def datetimeRebaseSpec(
lookupFileMeta: String => String,
- modeByConfig: String): RebaseSpec = {
+ modeByConfig: LegacyBehaviorPolicy.Value): RebaseSpec = {
getRebaseSpec(
lookupFileMeta,
modeByConfig,
@@ -161,7 +166,7 @@ object DataSourceUtils extends PredicateHelper {
def int96RebaseSpec(
lookupFileMeta: String => String,
- modeByConfig: String): RebaseSpec = {
+ modeByConfig: LegacyBehaviorPolicy.Value): RebaseSpec = {
getRebaseSpec(
lookupFileMeta,
modeByConfig,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index f82da44e73031..d3078740b819c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -182,6 +182,13 @@ trait FileFormat {
*/
def supportDataType(dataType: DataType): Boolean = true
+ /**
+ * Returns whether this format supports the given [[DataType]] in the read-only path.
+ * By default, it is the same as `supportDataType`. In certain file formats, it can allow more
+ * data types than `supportDataType`. At this point, only `CSVFileFormat` overrides it.
+ */
+ def supportReadDataType(dataType: DataType): Boolean = supportDataType(dataType)
+
/**
* Returns whether this format supports the given field name in the read/write path.
* By default all field names are supported.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 725b4a2332576..3c65ef139ea0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources
-import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NormalizeableRelation}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
@@ -45,7 +45,8 @@ case class LogicalRelation(
extends LeafNode
with StreamSourceAwareLogicalPlan
with MultiInstanceRelation
- with ExposesMetadataColumns {
+ with ExposesMetadataColumns
+ with NormalizeableRelation {
// Only care about relation when canonicalizing.
override def doCanonicalize(): LogicalPlan = copy(
@@ -101,6 +102,13 @@ case class LogicalRelation(
override def withStream(stream: SparkDataStream): LogicalRelation = copy(stream = Some(stream))
override def getStream: Option[SparkDataStream] = stream
+
+ /**
+ * Minimally normalizes this [[LogicalRelation]] to make it comparable in [[NormalizePlan]].
+ */
+ override def normalize(): LogicalPlan = {
+ copy(catalogTable = catalogTable.map(CatalogTable.normalize))
+ }
}
object LogicalRelation {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 402b70065d8e6..1bc4645dfc434 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.getPartitionValueString
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimeFormatter, TimestampFormatter}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
@@ -66,6 +66,7 @@ object PartitionSpec {
object PartitioningUtils extends SQLConfHelper {
+ val timePartitionPattern = "HH:mm:ss[.SSSSSS]"
val timestampPartitionPattern = "yyyy-MM-dd HH:mm:ss[.S]"
case class TypedPartValue(value: String, dataType: DataType)
@@ -145,10 +146,11 @@ object PartitioningUtils extends SQLConfHelper {
timestampPartitionPattern,
zoneId,
isParsing = true)
+ val timeFormatter = TimeFormatter(timePartitionPattern, isParsing = true)
// First, we need to parse every partition's path and see if we can find partition values.
val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes,
- validatePartitionColumns, zoneId, dateFormatter, timestampFormatter)
+ validatePartitionColumns, zoneId, dateFormatter, timestampFormatter, timeFormatter)
}.unzip
// We create pairs of (path -> path's partition value) here
@@ -240,7 +242,8 @@ object PartitioningUtils extends SQLConfHelper {
validatePartitionColumns: Boolean,
zoneId: ZoneId,
dateFormatter: DateFormatter,
- timestampFormatter: TimestampFormatter): (Option[PartitionValues], Option[Path]) = {
+ timestampFormatter: TimestampFormatter,
+ timeFormatter: TimeFormatter): (Option[PartitionValues], Option[Path]) = {
val columns = ArrayBuffer.empty[(String, TypedPartValue)]
// Old Hadoop versions don't have `Path.isRoot`
var finished = path.getParent == null
@@ -262,7 +265,7 @@ object PartitioningUtils extends SQLConfHelper {
// Once we get the string, we try to parse it and find the partition column and value.
val maybeColumn =
parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes,
- zoneId, dateFormatter, timestampFormatter)
+ zoneId, dateFormatter, timestampFormatter, timeFormatter)
maybeColumn.foreach(columns += _)
// Now, we determine if we should stop.
@@ -298,7 +301,8 @@ object PartitioningUtils extends SQLConfHelper {
userSpecifiedDataTypes: Map[String, DataType],
zoneId: ZoneId,
dateFormatter: DateFormatter,
- timestampFormatter: TimestampFormatter): Option[(String, TypedPartValue)] = {
+ timestampFormatter: TimestampFormatter,
+ timeFormatter: TimeFormatter): Option[(String, TypedPartValue)] = {
val equalSignIndex = columnSpec.indexOf('=')
if (equalSignIndex == -1) {
None
@@ -319,7 +323,8 @@ object PartitioningUtils extends SQLConfHelper {
typeInference,
zoneId,
dateFormatter,
- timestampFormatter)
+ timestampFormatter,
+ timeFormatter)
}
Some(columnName -> TypedPartValue(rawColumnValue, dataType))
}
@@ -427,22 +432,23 @@ object PartitioningUtils extends SQLConfHelper {
/**
* Converts a string to a [[Literal]] with automatic type inference. Currently only supports
* [[NullType]], [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType]], [[DateType]],
- * [[TimestampType]], and [[StringType]].
+ * [[TimestampType]], [[TimeType]], and [[StringType]].
*
* When resolving conflicts, it follows the table below:
*
- * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+
- * | InputA \ InputB | NullType | IntegerType | LongType | DecimalType(38,0)* | DoubleType | DateType | TimestampType | StringType |
- * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+
- * | NullType | NullType | IntegerType | LongType | DecimalType(38,0) | DoubleType | DateType | TimestampType | StringType |
- * | IntegerType | IntegerType | IntegerType | LongType | DecimalType(38,0) | DoubleType | StringType | StringType | StringType |
- * | LongType | LongType | LongType | LongType | DecimalType(38,0) | StringType | StringType | StringType | StringType |
- * | DecimalType(38,0)* | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | StringType | StringType | StringType | StringType |
- * | DoubleType | DoubleType | DoubleType | StringType | StringType | DoubleType | StringType | StringType | StringType |
- * | DateType | DateType | StringType | StringType | StringType | StringType | DateType | TimestampType | StringType |
- * | TimestampType | TimestampType | StringType | StringType | StringType | StringType | TimestampType | TimestampType | StringType |
- * | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType |
- * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+
+ * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+------------+
+ * | InputA \ InputB | NullType | IntegerType | LongType | DecimalType(38,0)* | DoubleType | DateType | TimestampType | StringType | TimeType |
+ * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+------------+
+ * | NullType | NullType | IntegerType | LongType | DecimalType(38,0) | DoubleType | DateType | TimestampType | StringType | TimeType |
+ * | IntegerType | IntegerType | IntegerType | LongType | DecimalType(38,0) | DoubleType | StringType | StringType | StringType | StringType |
+ * | LongType | LongType | LongType | LongType | DecimalType(38,0) | StringType | StringType | StringType | StringType | StringType |
+ * | DecimalType(38,0)* | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | StringType | StringType | StringType | StringType | StringType |
+ * | DoubleType | DoubleType | DoubleType | StringType | StringType | DoubleType | StringType | StringType | StringType | StringType |
+ * | DateType | DateType | StringType | StringType | StringType | StringType | DateType | TimestampType | StringType | StringType |
+ * | TimeType | TimeType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | TimeType |
+ * | TimestampType | TimestampType | StringType | StringType | StringType | StringType | TimestampType | TimestampType | StringType | StringType |
+ * | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType |
+ * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+------------+
* Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other
* combinations of scales and precisions because currently we only infer decimal type like
* `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type.
@@ -453,7 +459,8 @@ object PartitioningUtils extends SQLConfHelper {
typeInference: Boolean,
zoneId: ZoneId,
dateFormatter: DateFormatter,
- timestampFormatter: TimestampFormatter): DataType = {
+ timestampFormatter: TimestampFormatter,
+ timeFormatter: TimeFormatter): DataType = {
val decimalTry = Try {
// `BigDecimal` conversion can fail when the `field` is not a form of number.
val bigDecimal = new JBigDecimal(raw)
@@ -499,6 +506,20 @@ object PartitioningUtils extends SQLConfHelper {
timestampType
}
+ val timeTry = Try {
+ val unescapedRaw = unescapePathName(raw)
+ // Try to parse the time; if no exception occurs, this is a candidate to be resolved as
+ // TimeType.
+ timeFormatter.parse(unescapedRaw)
+ // We need to check that we can cast the raw string since we later use Cast to get
+ // the partition values with the right DataType (see
+ // org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning)
+ val timeValue = Cast(Literal(unescapedRaw), TimeType(), Some(zoneId.getId)).eval()
+ // Disallow TimeType if the cast returned null
+ require(timeValue != null)
+ TimeType()
+ }
+
if (typeInference) {
// First tries integral types
Try({ Integer.parseInt(raw); IntegerType })
@@ -509,6 +530,7 @@ object PartitioningUtils extends SQLConfHelper {
// Then falls back to date/timestamp types
.orElse(timestampTry)
.orElse(dateTry)
+ .orElse(timeTry)
// Then falls back to string
.getOrElse {
if (raw == DEFAULT_PARTITION_NAME) NullType else StringType
@@ -534,6 +556,7 @@ object PartitioningUtils extends SQLConfHelper {
case _: DecimalType => Literal(new JBigDecimal(value)).value
case DateType =>
Cast(Literal(value), DateType, Some(zoneId.getId)).eval()
+ case tt: TimeType => Cast(Literal(unescapePathName(value)), tt).eval()
// Timestamp types
case dt if AnyTimestampType.acceptsType(dt) =>
Try {
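
As a standalone illustration of the inference order documented above (integral types first, then decimal, double, date/time candidates, and finally string), here is a minimal Scala sketch. It is not the Spark code: `java.time` parsing stands in for Spark's formatters and the Cast check, and `DefaultPartitionName` is a hypothetical stand-in for the default-partition sentinel.

import java.time.{LocalDate, LocalTime}
import scala.util.Try

object PartitionValueInferenceSketch {
  // Hypothetical stand-in for the default-partition sentinel used in the code above.
  private val DefaultPartitionName = "__HIVE_DEFAULT_PARTITION__"

  // Tries narrower types first and falls back to string, mirroring the orElse chain above.
  def inferTypeName(raw: String): String = {
    Try { raw.toInt; "IntegerType" }
      .orElse(Try { raw.toLong; "LongType" })
      .orElse(Try { BigInt(raw); "DecimalType(38,0)" })
      .orElse(Try { raw.toDouble; "DoubleType" })
      // Date/time parsing is simplified here; Spark uses its own formatters plus a Cast check.
      .orElse(Try { LocalDate.parse(raw); "DateType" })
      .orElse(Try { LocalTime.parse(raw); "TimeType" })
      .getOrElse(if (raw == DefaultPartitionName) "NullType" else "StringType")
  }

  def main(args: Array[String]): Unit = {
    Seq("42", "9999999999", "1.1", "2024-01-01", "12:34:56", "abc").foreach { v =>
      println(s"$v -> ${inferTypeName(v)}")
    }
  }
}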
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index e9cc23c6a5bab..5960cf8c38ced 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -289,13 +289,13 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
relation: LogicalRelation,
hadoopFsRelation: HadoopFsRelation): LogicalPlan = {
val variants = new VariantInRelation
- val defaultValues = ResolveDefaultColumns.existenceDefaultValues(hadoopFsRelation.schema)
- // I'm not aware of any case that an attribute `relation.output` can have a different data type
- // than the corresponding field in `hadoopFsRelation.schema`. Other code seems to prefer using
- // the data type in `hadoopFsRelation.schema`, let's also stick to it.
- val schemaWithAttributes = hadoopFsRelation.schema.fields.zip(relation.output)
- for (((f, attr), defaultValue) <- schemaWithAttributes.zip(defaultValues)) {
- variants.addVariantFields(attr.exprId, f.dataType, defaultValue, Nil)
+
+ val schemaAttributes = relation.resolve(hadoopFsRelation.dataSchema,
+ hadoopFsRelation.sparkSession.sessionState.analyzer.resolver)
+ val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
+ schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
+ for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
+ variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
}
if (variants.mapping.isEmpty) return originalPlan
@@ -304,24 +304,28 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
// `collectRequestedFields` may have removed all variant columns.
if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
- val (newFields, newOutput) = schemaWithAttributes.map {
- case (f, attr) =>
- if (variants.mapping.get(attr.exprId).exists(_.nonEmpty)) {
- val newType = variants.rewriteType(attr.exprId, f.dataType, Nil)
- val newAttr = AttributeReference(f.name, newType, f.nullable, f.metadata)()
- (f.copy(dataType = newType), newAttr)
- } else {
- (f, attr)
- }
- }.unzip
+ val attributeMap = schemaAttributes.map { a =>
+ if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
+ val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
+ val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
+ qualifier = a.qualifier)
+ (a.exprId, newAttr)
+ } else {
+ // `relation.resolve` actually returns `Seq[AttributeReference]`, although the return type
+ // is `Seq[Attribute]`.
+ (a.exprId, a.asInstanceOf[AttributeReference])
+ }
+ }.toMap
+ val newFields = schemaAttributes.map { a =>
+ val dataType = attributeMap(a.exprId).dataType
+ StructField(a.name, dataType, a.nullable, a.metadata)
+ }
+ val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))(
hadoopFsRelation.sparkSession)
val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)
- val attributeMap = relation.output.zip(newOutput).map {
- case (oldAttr, newAttr) => oldAttr.exprId -> newAttr
- }.toMap
val withFilter = if (filters.nonEmpty) {
Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation)
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
index 54c100282e2db..87326615f3266 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
@@ -55,7 +55,7 @@ import org.apache.spark.util.SerializableConfiguration
* .load("/path/to/fileDir");
* }}}
*/
-class BinaryFileFormat extends FileFormat with DataSourceRegister {
+case class BinaryFileFormat() extends FileFormat with DataSourceRegister {
import BinaryFileFormat._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 6196bef106fa5..c6b9764bee2c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
/**
@@ -68,10 +68,14 @@ abstract class CSVDataSource extends Serializable {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): Option[StructType] = {
- if (inputPaths.nonEmpty) {
- Some(infer(sparkSession, inputPaths, parsedOptions))
- } else {
- None
+ parsedOptions.singleVariantColumn match {
+ case Some(columnName) => Some(StructType(Array(StructField(columnName, VariantType))))
+ case None =>
+ if (inputPaths.nonEmpty) {
+ Some(infer(sparkSession, inputPaths, parsedOptions))
+ } else {
+ None
+ }
}
}
@@ -89,6 +93,31 @@ object CSVDataSource extends Logging {
TextInputCSVDataSource
}
}
+
+ /**
+ * Returns a function that sets the header column names used in singleVariantColumn mode. The
+ * returned function takes an optional input, which is the header column names potentially read by
+ * `CSVHeaderChecker`. The function only needs to read the file when the input is empty (e.g.,
+ * `CSVHeaderChecker` won't read anything when the partition is not at the file start).
+ *
+ * We need to return a function here instead of letting `CSVHeaderChecker` call this function
+ * directly, because this package (including the `CSVUtils` class) already depends on
+ * `CSVHeaderChecker`, so the dependency cannot go the other way.
+ */
+ def setHeaderForSingleVariantColumn(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: UnivocityParser): Option[Option[Array[String]] => Unit] =
+ if (parser.options.needHeaderForSingleVariantColumn) {
+ Some(headerColumnNames => {
+ parser.headerColumnNames = headerColumnNames.orElse {
+ CSVUtils.readHeaderLine(file.toPath, parser.options, conf).map { line =>
+ new CsvParser(parser.options.asParserSettings).parseLine(line)
+ }
+ }
+ })
+ } else {
+ None
+ }
}
object TextInputCSVDataSource extends CSVDataSource {
@@ -110,6 +139,8 @@ object TextInputCSVDataSource extends CSVDataSource {
}
}
+ headerChecker.setHeaderForSingleVariantColumn =
+ CSVDataSource.setHeaderForSingleVariantColumn(conf, file, parser)
UnivocityParser.parseIterator(lines, parser, headerChecker, requiredSchema)
}
@@ -187,6 +218,8 @@ object MultiLineCSVDataSource extends CSVDataSource with Logging {
parser: UnivocityParser,
headerChecker: CSVHeaderChecker,
requiredSchema: StructType): Iterator[InternalRow] = {
+ headerChecker.setHeaderForSingleVariantColumn =
+ CSVDataSource.setHeaderForSingleVariantColumn(conf, file, parser)
UnivocityParser.parseStream(
CodecStreams.createInputStreamWithCloseResource(conf, file.toPath),
parser,
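
A minimal, self-contained sketch of the callback shape introduced above for singleVariantColumn mode: the callback is only produced when the parser needs the header, and it falls back to reading the file only when the checker did not already supply the column names. All names (`ParserStub`, `setHeaderCallback`, `readHeaderFromFile`) are illustrative, not the actual Spark classes.

object HeaderCallbackSketch {
  // Stand-in for the parser that needs header column names in single-variant-column mode.
  final class ParserStub(val needsHeader: Boolean) {
    var headerColumnNames: Option[Array[String]] = None
  }

  // Returns a callback only when the parser actually needs the header.
  def setHeaderCallback(
      readHeaderFromFile: () => Option[Array[String]],
      parser: ParserStub): Option[Option[Array[String]] => Unit] = {
    if (parser.needsHeader) {
      Some((maybeHeader: Option[Array[String]]) => {
        // Prefer the header already seen by the checker; otherwise read the file once.
        parser.headerColumnNames = maybeHeader.orElse(readHeaderFromFile())
      })
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    val parser = new ParserStub(needsHeader = true)
    val callback = setHeaderCallback(() => Some(Array("a", "b", "c")), parser)
    callback.foreach(_(None)) // the checker saw nothing, so the file is consulted
    println(parser.headerColumnNames.map(_.mkString(",")))
  }
}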
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index b2b99e2d0f4ea..a65f7bbbeba50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.SerializableConfiguration
/**
* Provides access to CSV data from pure SQL statements.
*/
-class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
+case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
override def shortName(): String = "csv"
@@ -68,6 +68,14 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
+ // This is a defensive check to ensure the schema doesn't contain variant. It shouldn't be
+ // triggered if the rest of the code is correct, because `supportDataType` doesn't allow
+ // variant; it only matters if a caller is not using `supportDataType`/`supportReadDataType`
+ // correctly.
+ dataSchema.foreach { field =>
+ if (!supportDataType(field.dataType, allowVariant = false)) {
+ throw QueryCompilationErrors.dataTypeUnsupportedByDataSourceError("CSV", field)
+ }
+ }
val conf = job.getConfiguration
val csvOptions = new CSVOptions(
options,
@@ -150,13 +158,20 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def toString: String = "CSV"
- override def hashCode(): Int = getClass.hashCode()
+ /**
+ * Allow reading variant from CSV, but don't allow writing variant into CSV. This is because the
+ * written data (the string representation of variant) may not be read back as the same variant.
+ */
+ override def supportDataType(dataType: DataType): Boolean =
+ supportDataType(dataType, allowVariant = false)
- override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
+ override def supportReadDataType(dataType: DataType): Boolean =
+ supportDataType(dataType, allowVariant = true)
- override def supportDataType(dataType: DataType): Boolean = dataType match {
- case _: VariantType => false
+ private def supportDataType(dataType: DataType, allowVariant: Boolean): Boolean = dataType match {
+ case _: VariantType => allowVariant
+ case _: TimeType => false
case _: AtomicType => true
case udt: UserDefinedType[_] => supportDataType(udt.sqlType)
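
To make the read/write asymmetry concrete, here is a small standalone predicate in the same shape as the change above, using a simplified, hypothetical type model rather than Spark's `DataType` hierarchy: variant is accepted only when `allowVariant` is true, so the read path passes `true` and the write path `false`, while `TimeType` is rejected on both.

object CsvTypeSupportSketch {
  // Simplified stand-ins for Spark's DataType hierarchy; illustrative only.
  sealed trait DataType
  case object VariantType extends DataType
  case object TimeType extends DataType
  case object StringType extends DataType
  case object IntegerType extends DataType

  private def supportDataType(dt: DataType, allowVariant: Boolean): Boolean = dt match {
    case VariantType => allowVariant // readable, never writable
    case TimeType    => false        // TIME is not supported by the CSV source at all
    case _           => true
  }

  // The write path forbids variant, the read path allows it.
  def supportWrite(dt: DataType): Boolean = supportDataType(dt, allowVariant = false)
  def supportRead(dt: DataType): Boolean  = supportDataType(dt, allowVariant = true)

  def main(args: Array[String]): Unit = {
    println(supportRead(VariantType))  // true
    println(supportWrite(VariantType)) // false
    println(supportWrite(TimeType))    // false
  }
}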
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
index 1a48b81fd7e64..853235bbd3895 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -17,10 +17,17 @@
package org.apache.spark.sql.execution.datasources.csv
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.util.LineReader
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.csv.CSVExprUtils
import org.apache.spark.sql.catalyst.csv.CSVOptions
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.CodecStreams
import org.apache.spark.sql.functions._
object CSVUtils {
@@ -130,4 +137,40 @@ object CSVUtils {
def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] =
CSVExprUtils.filterCommentAndEmpty(iter, options)
+
+ def readHeaderLine(filePath: Path, options: CSVOptions, conf: Configuration): Option[String] = {
+ val inputStream = CodecStreams.createInputStream(conf, filePath)
+ try {
+ val lines = new Iterator[String] {
+ private val in = options.lineSeparatorInRead match {
+ case Some(sep) => new LineReader(inputStream, sep)
+ case _ => new LineReader(inputStream)
+ }
+ private val text = new Text()
+ private var finished = false
+ private var hasValue = false
+
+ override def hasNext: Boolean = {
+ if (!finished && !hasValue) {
+ val bytesRead = in.readLine(text)
+ finished = bytesRead == 0
+ hasValue = !finished
+ }
+ !finished
+ }
+
+ override def next(): String = {
+ if (!hasValue) {
+ throw QueryExecutionErrors.endOfStreamError()
+ }
+ hasValue = false
+ new String(text.getBytes, 0, text.getLength, options.charset)
+ }
+ }
+ val filteredLines = CSVUtils.filterCommentAndEmpty(lines, options)
+ filteredLines.buffered.headOption
+ } finally {
+ inputStream.close()
+ }
+ }
}
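
A rough standalone analogue of `readHeaderLine` above, assuming `java.io.BufferedReader` in place of Hadoop's `LineReader` and codec streams: wrap the line source in a Scala `Iterator`, filter comment and empty lines, and keep only the first remaining line.

import java.io.{BufferedReader, StringReader}

object FirstLineSketch {
  // Wraps a BufferedReader (stand-in for Hadoop's LineReader) as a Scala Iterator.
  private def lines(reader: BufferedReader): Iterator[String] = new Iterator[String] {
    private var nextLine: String = _
    private var finished = false

    override def hasNext: Boolean = {
      if (!finished && nextLine == null) {
        nextLine = reader.readLine()
        finished = nextLine == null
      }
      !finished
    }

    override def next(): String = {
      if (!hasNext) throw new NoSuchElementException("end of stream")
      val line = nextLine
      nextLine = null
      line
    }
  }

  // Returns the first non-comment, non-empty line, mirroring the header lookup above.
  def firstDataLine(text: String, comment: Char = '#'): Option[String] = {
    val reader = new BufferedReader(new StringReader(text))
    try {
      lines(reader).filter(l => l.nonEmpty && !l.startsWith(comment.toString)).buffered.headOption
    } finally {
      reader.close()
    }
  }

  def main(args: Array[String]): Unit =
    println(firstDataLine("# comment\n\nname,age\n1,2")) // Some(name,age)
}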
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 651c29d097663..8112cf1c80ef9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -203,7 +203,11 @@ object JdbcUtils extends Logging with SQLConfHelper {
case java.sql.Types.DECIMAL | java.sql.Types.NUMERIC if scale < 0 =>
DecimalType.bounded(precision - scale, 0)
case java.sql.Types.DECIMAL | java.sql.Types.NUMERIC =>
- DecimalPrecisionTypeCoercion.bounded(precision, scale)
+ DecimalPrecisionTypeCoercion.bounded(
+ // A safeguard for the case where the JDBC scale is larger than the precision, which is
+ // not supported by Spark.
+ math.max(precision, scale),
+ scale)
case java.sql.Types.DOUBLE => DoubleType
case java.sql.Types.FLOAT => FloatType
case java.sql.Types.INTEGER => if (signed) IntegerType else LongType
@@ -314,6 +318,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
metadata.putBoolean("isSigned", isSigned)
metadata.putBoolean("isTimestampNTZ", isTimestampNTZ)
metadata.putLong("scale", fieldScale)
+ metadata.putString("jdbcClientType", typeName)
dialect.updateExtraColumnMeta(conn, rsmd, i + 1, metadata)
val columnType =
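
A simplified sketch of the decimal clamping with the new `math.max(precision, scale)` guard. The real clamping lives in `DecimalPrecisionTypeCoercion.bounded`; this standalone version only shows why the guard matters when a driver reports a scale larger than the precision.

object JdbcDecimalBoundSketch {
  val MaxPrecision = 38

  // Simplified clamping: widen the precision to cover the scale, then cap both at 38.
  def boundedDecimal(jdbcPrecision: Int, jdbcScale: Int): (Int, Int) = {
    val precision = math.max(jdbcPrecision, jdbcScale) // guard: precision must cover the scale
    (math.min(precision, MaxPrecision), math.min(jdbcScale, MaxPrecision))
  }

  def main(args: Array[String]): Unit = {
    // e.g. a driver reporting NUMERIC(10, 12): without the guard, precision < scale is invalid.
    println(boundedDecimal(10, 12)) // (12, 12)
  }
}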
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 3a4ca99fc95a4..bedf5ec62e4ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.classic.ClassicConversions.castToImpl
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -61,10 +61,14 @@ abstract class JsonDataSource extends Serializable {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): Option[StructType] = {
- if (inputPaths.nonEmpty) {
- Some(infer(sparkSession, inputPaths, parsedOptions))
- } else {
- None
+ parsedOptions.singleVariantColumn match {
+ case Some(columnName) => Some(StructType(Array(StructField(columnName, VariantType))))
+ case None =>
+ if (inputPaths.nonEmpty) {
+ Some(infer(sparkSession, inputPaths, parsedOptions))
+ } else {
+ None
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 6174c017f6047..e38ca137b162d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
-class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
+case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister {
override val shortName: String = "json"
override def isSplitable(
@@ -55,10 +55,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
options,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
- parsedOptions.singleVariantColumn match {
- case Some(columnName) => Some(StructType(Array(StructField(columnName, VariantType))))
- case None => JsonDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
- }
+ JsonDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
}
override def prepareWrite(
@@ -131,13 +128,10 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def toString: String = "JSON"
- override def hashCode(): Int = getClass.hashCode()
-
- override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat]
-
override def supportDataType(dataType: DataType): Boolean = dataType match {
case _: VariantType => true
+ case _: TimeType => false
case _: AtomicType => true
case st: StructType => st.forall { f => supportDataType(f.dataType) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 5513359fdaa31..eea446a492804 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -246,6 +246,7 @@ class OrcFileFormat
override def supportDataType(dataType: DataType): Boolean = dataType match {
case _: VariantType => false
+ case _: TimeType => false
case _: AtomicType => true
case st: StructType => st.forall { f => supportDataType(f.dataType) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 3dc9ddf386f10..565742671b9cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -21,7 +21,7 @@ import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float =
import java.math.{BigDecimal => JBigDecimal}
import java.nio.charset.StandardCharsets.UTF_8
import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate, Period}
+import java.time.{Duration, Instant, LocalDate, LocalTime, Period}
import java.util.HashSet
import java.util.Locale
@@ -149,6 +149,8 @@ class ParquetFilters(
ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS), INT64, 0)
private val ParquetTimestampMillisType =
ParquetSchemaType(LogicalTypeAnnotation.timestampType(true, TimeUnit.MILLIS), INT64, 0)
+ private val ParquetTimeType =
+ ParquetSchemaType(LogicalTypeAnnotation.timeType(false, TimeUnit.MICROS), INT64, 0)
private def dateToDays(date: Any): Int = {
val gregorianDays = date match {
@@ -173,6 +175,10 @@ class ParquetFilters(
}
}
+ private def localTimeToMicros(v: Any): JLong = {
+ DateTimeUtils.localTimeToMicros(v.asInstanceOf[LocalTime])
+ }
+
private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue()
private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue()
@@ -207,6 +213,7 @@ class ParquetFilters(
private def toLongValue(v: Any): JLong = v match {
case d: Duration => IntervalUtils.durationToMicros(d)
+ case lt: LocalTime => DateTimeUtils.localTimeToMicros(lt)
case l => l.asInstanceOf[JLong]
}
@@ -244,6 +251,10 @@ class ParquetFilters(
(n: Array[String], v: Any) => FilterApi.eq(
longColumn(n),
Option(v).map(timestampToMillis).orNull)
+ case ParquetTimeType =>
+ (n: Array[String], v: Any) => FilterApi.eq(
+ longColumn(n),
+ Option(v).map(localTimeToMicros).orNull)
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
(n: Array[String], v: Any) => FilterApi.eq(
@@ -293,6 +304,10 @@ class ParquetFilters(
(n: Array[String], v: Any) => FilterApi.notEq(
longColumn(n),
Option(v).map(timestampToMillis).orNull)
+ case ParquetTimeType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(
+ longColumn(n),
+ Option(v).map(localTimeToMicros).orNull)
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
(n: Array[String], v: Any) => FilterApi.notEq(
@@ -333,6 +348,8 @@ class ParquetFilters(
(n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v))
case ParquetTimestampMillisType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v))
+ case ParquetTimeType =>
+ (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), localTimeToMicros(v))
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
(n: Array[String], v: Any) =>
@@ -370,6 +387,8 @@ class ParquetFilters(
(n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v))
case ParquetTimestampMillisType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v))
+ case ParquetTimeType =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), localTimeToMicros(v))
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
(n: Array[String], v: Any) =>
@@ -407,6 +426,8 @@ class ParquetFilters(
(n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v))
case ParquetTimestampMillisType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v))
+ case ParquetTimeType =>
+ (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), localTimeToMicros(v))
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
(n: Array[String], v: Any) =>
@@ -444,6 +465,8 @@ class ParquetFilters(
(n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v))
case ParquetTimestampMillisType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v))
+ case ParquetTimeType =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), localTimeToMicros(v))
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
(n: Array[String], v: Any) =>
@@ -533,6 +556,14 @@ class ParquetFilters(
}
FilterApi.in(longColumn(n), set)
+ case ParquetTimeType =>
+ (n: Array[String], values: Array[Any]) =>
+ val set = new HashSet[JLong]()
+ for (value <- values) {
+ set.add(Option(value).map(localTimeToMicros).orNull)
+ }
+ FilterApi.in(longColumn(n), set)
+
case ParquetSchemaType(_: DecimalLogicalTypeAnnotation, INT32, _) if pushDownDecimal =>
(n: Array[String], values: Array[Any]) =>
val set = new HashSet[Integer]()
@@ -620,7 +651,8 @@ class ParquetFilters(
case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue
case _ => false
}
- case ParquetLongType => value.isInstanceOf[JLong] || value.isInstanceOf[Duration]
+ case ParquetLongType =>
+ value.isInstanceOf[JLong] || value.isInstanceOf[Duration]
case ParquetFloatType => value.isInstanceOf[JFloat]
case ParquetDoubleType => value.isInstanceOf[JDouble]
case ParquetStringType => value.isInstanceOf[String]
@@ -629,6 +661,7 @@ class ParquetFilters(
value.isInstanceOf[Date] || value.isInstanceOf[LocalDate]
case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant]
+ case ParquetTimeType => value.isInstanceOf[LocalTime]
case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT32, _) =>
isDecimalMatched(value, decimalType)
case ParquetSchemaType(decimalType: DecimalLogicalTypeAnnotation, INT64, _) =>
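
The pushed-down comparison value for a TIME(MICROS) column is the microseconds-since-midnight encoding of a `java.time.LocalTime`. A minimal sketch of that conversion (equivalent arithmetic to what `DateTimeUtils.localTimeToMicros` is used for above; not the Spark utility itself):

import java.time.LocalTime

object TimeFilterValueSketch {
  // Micros since midnight, matching Parquet TIME(MICROS, isAdjustedToUTC = false) encoding.
  def localTimeToMicros(t: LocalTime): Long = t.toNanoOfDay / 1000L

  def main(args: Array[String]): Unit = {
    val t = LocalTime.of(12, 34, 56, 789000000)
    // 12:34:56.789 -> 45296789000 microseconds since midnight
    println(localTimeToMicros(t))
  }
}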
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
index e795d156d7646..eaedd99d8628c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
@@ -24,7 +24,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
/**
* Options for the Parquet data source.
@@ -74,14 +74,15 @@ class ParquetOptions(
/**
* The rebasing mode for the DATE and TIMESTAMP_MICROS, TIMESTAMP_MILLIS values in reads.
*/
- def datetimeRebaseModeInRead: String = parameters
+ def datetimeRebaseModeInRead: LegacyBehaviorPolicy.Value = parameters
.get(DATETIME_REBASE_MODE)
+ .map(LegacyBehaviorPolicy.withName)
.getOrElse(sqlConf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_READ))
/**
* The rebasing mode for INT96 timestamp values in reads.
*/
- def int96RebaseModeInRead: String = parameters
- .get(INT96_REBASE_MODE)
+ def int96RebaseModeInRead: LegacyBehaviorPolicy.Value = parameters
+ .get(INT96_REBASE_MODE).map(LegacyBehaviorPolicy.withName)
.getOrElse(sqlConf.getConf(SQLConf.PARQUET_INT96_REBASE_MODE_IN_READ))
}
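
A standalone sketch of the option-parsing shape used above: the per-source option, if present, is converted to the enum value up front, otherwise the already-typed session default is used. `Policy` is a hypothetical stand-in for `LegacyBehaviorPolicy`.

object RebaseModeOptionSketch {
  // Stand-in for Spark's LegacyBehaviorPolicy enumeration.
  object Policy extends Enumeration {
    val EXCEPTION, LEGACY, CORRECTED = Value
  }

  // Per-source option wins; otherwise fall back to the session default, already as an enum.
  def rebaseMode(options: Map[String, String], sessionDefault: Policy.Value): Policy.Value =
    options.get("datetimeRebaseMode").map(Policy.withName).getOrElse(sessionDefault)

  def main(args: Array[String]): Unit = {
    println(rebaseMode(Map("datetimeRebaseMode" -> "LEGACY"), Policy.CORRECTED)) // LEGACY
    println(rebaseMode(Map.empty, Policy.CORRECTED))                             // CORRECTED
  }
}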
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 550c2af43a706..0927f5c3c963c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -481,6 +481,16 @@ private[parquet] class ParquetRowConverter(
}
}
+ case _: TimeType
+ if parquetType.getLogicalTypeAnnotation.isInstanceOf[TimeLogicalTypeAnnotation] &&
+ parquetType.getLogicalTypeAnnotation
+ .asInstanceOf[TimeLogicalTypeAnnotation].getUnit == TimeUnit.MICROS =>
+ new ParquetPrimitiveConverter(updater) {
+ override def addLong(value: Long): Unit = {
+ this.updater.setLong(value)
+ }
+ }
+
// A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor
// annotated by `LIST` or `MAP` should be interpreted as a required list of required
// elements where the element type is the type of the field.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index daeb8e88a924b..76073c3b050bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -284,7 +284,11 @@ class ParquetToSparkSchemaConverter(
case timestamp: TimestampLogicalTypeAnnotation
if timestamp.getUnit == TimeUnit.NANOS && nanosAsLong =>
LongType
- case _ => illegalType()
+ case time: TimeLogicalTypeAnnotation
+ if time.getUnit == TimeUnit.MICROS && !time.isAdjustedToUTC =>
+ TimeType(TimeType.MICROS_PRECISION)
+ case _ =>
+ illegalType()
}
case INT96 =>
@@ -578,6 +582,10 @@ class SparkToParquetSchemaConverter(
Types.primitive(INT32, repetition)
.as(LogicalTypeAnnotation.dateType()).named(field.name)
+ case _: TimeType =>
+ Types.primitive(INT64, repetition)
+ .as(LogicalTypeAnnotation.timeType(false, TimeUnit.MICROS)).named(field.name)
+
// NOTE: Spark SQL can write timestamp values to Parquet using INT96, TIMESTAMP_MICROS or
// TIMESTAMP_MILLIS. TIMESTAMP_MICROS is recommended but INT96 is the default to keep the
// behavior same as before.
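
A small sketch of the schema mapping rule added above, with a hypothetical `ParquetTime` case class standing in for parquet-mr's `TimeLogicalTypeAnnotation`: on read, only TIME(MICROS) with `isAdjustedToUTC = false` becomes Spark's `TimeType`; on write, `TimeType` is always emitted as INT64 annotated with TIME(MICROS, false).

object TimeTypeMappingSketch {
  // Simplified model of the Parquet TIME logical type annotation; illustrative only.
  final case class ParquetTime(unit: String, isAdjustedToUTC: Boolean)

  // Read side: only TIME(MICROS) with isAdjustedToUTC = false maps to Spark's TimeType.
  def toSparkType(annotation: ParquetTime): Option[String] = annotation match {
    case ParquetTime("MICROS", false) => Some("TimeType(6)")
    case _                            => None // anything else remains unsupported/illegal
  }

  // Write side: TimeType is emitted as INT64 annotated with TIME(MICROS, false).
  def toParquetType(sparkType: String): Option[(String, ParquetTime)] = sparkType match {
    case "TimeType(6)" => Some(("INT64", ParquetTime("MICROS", isAdjustedToUTC = false)))
    case _             => None
  }

  def main(args: Array[String]): Unit = {
    println(toSparkType(ParquetTime("MICROS", isAdjustedToUTC = false))) // Some(TimeType(6))
    println(toParquetType("TimeType(6)"))
  }
}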
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
index 35eb57a2e4fb2..4022f7ea30032 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
@@ -83,8 +83,7 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
private val decimalBuffer =
new Array[Byte](Decimal.minBytesForPrecision(DecimalType.MAX_PRECISION))
- private val datetimeRebaseMode = LegacyBehaviorPolicy.withName(
- SQLConf.get.getConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE))
+ private val datetimeRebaseMode = SQLConf.get.getConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE)
private val dateRebaseFunc = DataSourceUtils.createDateRebaseFuncInWrite(
datetimeRebaseMode, "Parquet")
@@ -92,8 +91,7 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
private val timestampRebaseFunc = DataSourceUtils.createTimestampRebaseFuncInWrite(
datetimeRebaseMode, "Parquet")
- private val int96RebaseMode = LegacyBehaviorPolicy.withName(
- SQLConf.get.getConf(SQLConf.PARQUET_INT96_REBASE_MODE_IN_WRITE))
+ private val int96RebaseMode = SQLConf.get.getConf(SQLConf.PARQUET_INT96_REBASE_MODE_IN_WRITE)
private val int96RebaseFunc = DataSourceUtils.createTimestampRebaseFuncInWrite(
int96RebaseMode, "Parquet INT96")
@@ -211,7 +209,7 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
(row: SpecializedGetters, ordinal: Int) =>
recordConsumer.addInteger(row.getInt(ordinal))
- case LongType | _: DayTimeIntervalType =>
+ case LongType | _: DayTimeIntervalType | _: TimeType =>
(row: SpecializedGetters, ordinal: Int) =>
recordConsumer.addLong(row.getLong(ordinal))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 6420d3ab374e5..4e38e9acb5531 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -81,8 +81,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
// We put the resolved relation into the [[AnalyzerBridgeState]] for
// it to be later reused by the single-pass [[Resolver]] to avoid resolving the
// relation metadata twice.
- AnalysisContext.get.getSinglePassResolverBridgeState.map { bridgeState =>
- bridgeState.relationsWithResolvedMetadata.put(unresolvedRelation, resolvedRelation)
+ AnalysisContext.get.getSinglePassResolverBridgeState.foreach { bridgeState =>
+ bridgeState.addUnresolvedRelation(unresolvedRelation, resolvedRelation)
}
case _ =>
})
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 3f2024126717d..7af239f99d45e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -37,7 +37,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
/**
* A data source for reading text files. The text files must be encoded as UTF-8.
*/
-class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
+case class TextFileFormat() extends TextBasedFileFormat with DataSourceRegister {
override def shortName(): String = "text"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala
index f55fbafe11ddb..f0812245bcce0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.TableSpec
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, TableCatalog, TableInfo}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -43,7 +43,12 @@ case class CreateTableExec(
override protected def run(): Seq[InternalRow] = {
if (!catalog.tableExists(identifier)) {
try {
- catalog.createTable(identifier, columns, partitioning.toArray, tableProperties.asJava)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(partitioning.toArray)
+ .withProperties(tableProperties.asJava)
+ .build()
+ catalog.createTable(identifier, tableInfo)
} catch {
case _: TableAlreadyExistsException if ignoreIfExists =>
logWarning(
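
The `TableInfo.Builder` calls above gather columns, partitioning, and properties into a single object before handing it to the catalog. The sketch below mirrors that builder shape with a hypothetical standalone builder (strings instead of Spark's `Column`/`Transform` types), just to show the call pattern:

import scala.collection.mutable

object TableInfoBuilderSketch {
  // Minimal stand-in mirroring the builder used in the diff; not the real Spark TableInfo.
  final case class TableInfoLike(
      columns: Seq[String],
      partitions: Seq[String],
      properties: Map[String, String])

  final class Builder {
    private var columns: Seq[String] = Nil
    private var partitions: Seq[String] = Nil
    private val properties = mutable.LinkedHashMap.empty[String, String]

    def withColumns(cols: Seq[String]): Builder = { columns = cols; this }
    def withPartitions(parts: Seq[String]): Builder = { partitions = parts; this }
    def withProperties(props: Map[String, String]): Builder = { properties ++= props; this }
    def build(): TableInfoLike = TableInfoLike(columns, partitions, properties.toMap)
  }

  def main(args: Array[String]): Unit = {
    // Same shape as the diff: collect everything into one object, then hand it to the catalog.
    val info = new Builder()
      .withColumns(Seq("id INT", "name STRING"))
      .withPartitions(Seq("bucket(16, id)"))
      .withProperties(Map("provider" -> "parquet"))
      .build()
    println(info)
  }
}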
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index bca3146df2766..d1b6d65509e2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder}
+import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, ResolveTableConstraints, V2ExpressionBuilder}
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable}
import org.apache.spark.sql.connector.catalog.index.SupportsIndex
@@ -170,6 +170,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
val continuousStream = r.stream.asInstanceOf[ContinuousStream]
val scanExec = ContinuousScanExec(r.output, r.scan, continuousStream, r.startOffset.get)
+ // initialize partitions
+ scanExec.inputPartitions
// Add a Project here to make sure we produce unsafe rows.
DataSourceV2Strategy.withProjectAndFilter(p, f, scanExec, !scanExec.supportsColumnar) :: Nil
@@ -183,11 +185,14 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case c @ CreateTable(ResolvedIdentifier(catalog, ident), columns, partitioning,
tableSpec: TableSpec, ifNotExists) =>
- ResolveDefaultColumns.validateCatalogForDefaultValue(columns, catalog.asTableCatalog, ident)
+ val tableCatalog = catalog.asTableCatalog
+ ResolveDefaultColumns.validateCatalogForDefaultValue(columns, tableCatalog, ident)
+ ResolveTableConstraints.validateCatalogForTableConstraint(
+ tableSpec.constraints, tableCatalog, ident)
val statementType = "CREATE TABLE"
GeneratedColumn.validateGeneratedColumns(
- c.tableSchema, catalog.asTableCatalog, ident, statementType)
- IdentityColumn.validateIdentityColumn(c.tableSchema, catalog.asTableCatalog, ident)
+ c.tableSchema, tableCatalog, ident, statementType)
+ IdentityColumn.validateIdentityColumn(c.tableSchema, tableCatalog, ident)
CreateTableExec(
catalog.asTableCatalog,
@@ -213,11 +218,14 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case c @ ReplaceTable(
ResolvedIdentifier(catalog, ident), columns, parts, tableSpec: TableSpec, orCreate) =>
- ResolveDefaultColumns.validateCatalogForDefaultValue(columns, catalog.asTableCatalog, ident)
+ val tableCatalog = catalog.asTableCatalog
+ ResolveDefaultColumns.validateCatalogForDefaultValue(columns, tableCatalog, ident)
+ ResolveTableConstraints.validateCatalogForTableConstraint(
+ tableSpec.constraints, tableCatalog, ident)
val statementType = "REPLACE TABLE"
GeneratedColumn.validateGeneratedColumns(
- c.tableSchema, catalog.asTableCatalog, ident, statementType)
- IdentityColumn.validateIdentityColumn(c.tableSchema, catalog.asTableCatalog, ident)
+ c.tableSchema, tableCatalog, ident, statementType)
+ IdentityColumn.validateIdentityColumn(c.tableSchema, tableCatalog, ident)
val v2Columns = columns.map(_.toV2Column(statementType)).toArray
catalog match {
@@ -225,7 +233,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
AtomicReplaceTableExec(staging, ident, v2Columns, parts,
qualifyLocInTableSpec(tableSpec), orCreate = orCreate, invalidateCache) :: Nil
case _ =>
- ReplaceTableExec(catalog.asTableCatalog, ident, v2Columns, parts,
+ ReplaceTableExec(tableCatalog, ident, v2Columns, parts,
qualifyLocInTableSpec(tableSpec), orCreate = orCreate, invalidateCache) :: Nil
}
@@ -409,9 +417,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case DropNamespace(ResolvedNamespace(catalog, ns, _), ifExists, cascade) =>
DropNamespaceExec(catalog, ns, ifExists, cascade) :: Nil
- case ShowNamespaces(ResolvedNamespace(catalog, ns, _), pattern, output) =>
- ShowNamespacesExec(output, catalog.asNamespaceCatalog, ns, pattern) :: Nil
-
case ShowTables(ResolvedNamespace(catalog, ns, _), pattern, output) =>
ShowTablesExec(output, catalog.asTableCatalog, ns, pattern) :: Nil
@@ -531,6 +536,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case a: AlterTableCommand if a.table.resolved =>
val table = a.table.asInstanceOf[ResolvedTable]
+ ResolveTableConstraints.validateCatalogForTableChange(
+ a.changes, table.catalog, table.identifier)
AlterTableExec(table.catalog, table.identifier, a.changes) :: Nil
case CreateIndex(ResolvedTable(_, _, table, _),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
index 894a3a10d4193..51f5c848bd27b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.TableSpec
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -48,7 +48,12 @@ case class ReplaceTableExec(
} else if (!orCreate) {
throw QueryCompilationErrors.cannotReplaceMissingTableError(ident)
}
- catalog.createTable(ident, columns, partitioning.toArray, tableProperties.asJava)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(partitioning.toArray)
+ .withProperties(tableProperties.asJava)
+ .build()
+ catalog.createTable(ident, tableInfo)
Seq.empty
}
@@ -75,12 +80,20 @@ case class AtomicReplaceTableExec(
invalidateCache(catalog, table, identifier)
}
val staged = if (orCreate) {
- catalog.stageCreateOrReplace(
- identifier, columns, partitioning.toArray, tableProperties.asJava)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(partitioning.toArray)
+ .withProperties(tableProperties.asJava)
+ .build()
+ catalog.stageCreateOrReplace(identifier, tableInfo)
} else if (catalog.tableExists(identifier)) {
try {
- catalog.stageReplace(
- identifier, columns, partitioning.toArray, tableProperties.asJava)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(partitioning.toArray)
+ .withProperties(tableProperties.asJava)
+ .build()
+ catalog.stageReplace(identifier, tableInfo)
} catch {
case e: NoSuchTableException =>
throw QueryCompilationErrors.cannotReplaceMissingTableError(identifier, Some(e))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
deleted file mode 100644
index c55c7b9f98544..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.v2
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.util.StringUtils
-import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper
-import org.apache.spark.sql.connector.catalog.SupportsNamespaces
-import org.apache.spark.sql.execution.LeafExecNode
-
-/**
- * Physical plan node for showing namespaces.
- */
-case class ShowNamespacesExec(
- output: Seq[Attribute],
- catalog: SupportsNamespaces,
- namespace: Seq[String],
- pattern: Option[String]) extends V2CommandExec with LeafExecNode {
-
- override protected def run(): Seq[InternalRow] = {
- val namespaces = if (namespace.nonEmpty) {
- catalog.listNamespaces(namespace.toArray)
- } else {
- catalog.listNamespaces()
- }
-
- // Please refer to the rule `KeepLegacyOutputs` for details about legacy command.
- // The legacy SHOW DATABASES command does not quote the database names.
- val isLegacy = output.head.name == "databaseName"
- val namespaceNames = if (isLegacy && namespaces.forall(_.length == 1)) {
- namespaces.map(_.head)
- } else {
- namespaces.map(_.quoted)
- }
-
- val rows = new ArrayBuffer[InternalRow]()
- namespaceNames.map { ns =>
- if (pattern.map(StringUtils.filterPattern(Seq(ns), _).nonEmpty).getOrElse(true)) {
- rows += toCatalystRow(ns)
- }
- }
-
- rows.toSeq
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
index 358f35e11d655..3eadffb8f0ae4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
@@ -69,17 +69,19 @@ sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
def write: V1Write
override def run(): Seq[InternalRow] = {
- writeWithV1(write.toInsertableRelation)
- refreshCache()
+ try {
+ writeWithV1(write.toInsertableRelation)
+ refreshCache()
- write.reportDriverMetrics().foreach { customTaskMetric =>
- metrics.get(customTaskMetric.name()).foreach(_.set(customTaskMetric.value()))
- }
-
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+ Nil
+ } finally {
+ write.reportDriverMetrics().foreach { customTaskMetric =>
+ metrics.get(customTaskMetric.name()).foreach(_.set(customTaskMetric.value()))
+ }
- Nil
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+ }
}
}
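
The rewrite above moves metric reporting into a `finally` block so driver metrics are posted even when the v1 write fails. A minimal standalone sketch of that pattern, with hypothetical `write`/`postMetrics` callbacks:

object DriverMetricsSketch {
  // Runs the write and always flushes metrics, even if the write throws.
  def writeWithMetrics(write: () => Unit, postMetrics: () => Unit): Unit = {
    try {
      write()
    } finally {
      // Metrics collected so far are still useful for debugging a failed write.
      postMetrics()
    }
  }

  def main(args: Array[String]): Unit = {
    try {
      writeWithMetrics(
        write = () => throw new RuntimeException("boom"),
        postMetrics = () => println("metrics posted"))
    } catch {
      case e: RuntimeException => println(s"write failed: ${e.getMessage}")
    }
  }
}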
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 9d059416766a1..d8e3a7eaf5aca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -45,7 +45,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case a @ AppendData(r: DataSourceV2Relation, query, options, _, None, _) =>
- val writeOptions = mergeOptions(options, r.options.asScala.toMap)
+ val writeOptions = mergeOptions(options, r.options.asCaseSensitiveMap.asScala.toMap)
val writeBuilder = newWriteBuilder(r.table, writeOptions, query.schema)
val write = writeBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
@@ -63,7 +63,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
}.toArray
val table = r.table
- val writeOptions = mergeOptions(options, r.options.asScala.toMap)
+ val writeOptions = mergeOptions(options, r.options.asCaseSensitiveMap.asScala.toMap)
val writeBuilder = newWriteBuilder(table, writeOptions, query.schema)
val write = writeBuilder match {
case builder: SupportsTruncate if isTruncate(predicates) =>
@@ -79,7 +79,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
val table = r.table
- val writeOptions = mergeOptions(options, r.options.asScala.toMap)
+ val writeOptions = mergeOptions(options, r.options.asCaseSensitiveMap.asScala.toMap)
val writeBuilder = newWriteBuilder(table, writeOptions, query.schema)
val write = writeBuilder match {
case builder: SupportsDynamicOverwrite =>
@@ -93,7 +93,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
case WriteToMicroBatchDataSource(
relationOpt, table, query, queryId, options, outputMode, Some(batchId)) =>
val writeOptions = mergeOptions(
- options, relationOpt.map(r => r.options.asScala.toMap).getOrElse(Map.empty))
+ options,
+ relationOpt.map(r => r.options.asCaseSensitiveMap.asScala.toMap).getOrElse(Map.empty))
val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId = queryId)
val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
@@ -105,14 +106,14 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, projections, _, None) =>
val rowSchema = projections.rowProjection.schema
val metadataSchema = projections.metadataProjection.map(_.schema)
- val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap)
+ val writeOptions = mergeOptions(Map.empty, r.options.asCaseSensitiveMap.asScala.toMap)
val writeBuilder = newWriteBuilder(r.table, writeOptions, rowSchema, metadataSchema)
val write = writeBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
rd.copy(write = Some(write), query = newQuery)
case wd @ WriteDelta(r: DataSourceV2Relation, _, query, _, projections, None) =>
- val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap)
+ val writeOptions = mergeOptions(Map.empty, r.options.asCaseSensitiveMap.asScala.toMap)
val deltaWriteBuilder = newDeltaWriteBuilder(r.table, writeOptions, projections)
val deltaWrite = deltaWriteBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(deltaWrite, query, r.funCatalog)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 016d6b5411acb..2d1964f6a2170 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, TableSpec, UnaryNode}
import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage}
@@ -82,10 +82,13 @@ case class CreateTableAsSelectExec(
}
throw QueryCompilationErrors.tableAlreadyExistsError(ident)
}
- val table = Option(catalog.createTable(
- ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
- partitioning.toArray, properties.asJava)
- ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(getV2Columns(query.schema, catalog.useNullableQuerySchema))
+ .withPartitions(partitioning.toArray)
+ .withProperties(properties.asJava)
+ .build()
+ val table = Option(catalog.createTable(ident, tableInfo))
+ .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
writeToTable(catalog, table, writeOptions, ident, query)
}
}
@@ -120,9 +123,12 @@ case class AtomicCreateTableAsSelectExec(
}
throw QueryCompilationErrors.tableAlreadyExistsError(ident)
}
- val stagedTable = Option(catalog.stageCreate(
- ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
- partitioning.toArray, properties.asJava)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(getV2Columns(query.schema, catalog.useNullableQuerySchema))
+ .withPartitions(partitioning.toArray)
+ .withProperties(properties.asJava)
+ .build()
+ val stagedTable = Option(catalog.stageCreate(ident, tableInfo)
).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
writeToTable(catalog, stagedTable, writeOptions, ident, query)
}
@@ -167,10 +173,13 @@ case class ReplaceTableAsSelectExec(
} else if (!orCreate) {
throw QueryCompilationErrors.cannotReplaceMissingTableError(ident)
}
- val table = Option(catalog.createTable(
- ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
- partitioning.toArray, properties.asJava)
- ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(getV2Columns(query.schema, catalog.useNullableQuerySchema))
+ .withPartitions(partitioning.toArray)
+ .withProperties(properties.asJava)
+ .build()
+ val table = Option(catalog.createTable(ident, tableInfo))
+ .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
writeToTable(catalog, table, writeOptions, ident, query)
}
}
@@ -210,12 +219,20 @@ case class AtomicReplaceTableAsSelectExec(
invalidateCache(catalog, table, ident)
}
val staged = if (orCreate) {
- catalog.stageCreateOrReplace(
- ident, columns, partitioning.toArray, properties.asJava)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(partitioning.toArray)
+ .withProperties(properties.asJava)
+ .build()
+ catalog.stageCreateOrReplace(ident, tableInfo)
} else if (catalog.tableExists(ident)) {
try {
- catalog.stageReplace(
- ident, columns, partitioning.toArray, properties.asJava)
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withPartitions(partitioning.toArray)
+ .withProperties(properties.asJava)
+ .build()
+ catalog.stageReplace(ident, tableInfo)
} catch {
case e: NoSuchTableException =>
throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e))
@@ -356,8 +373,11 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
}.toMap
override protected def run(): Seq[InternalRow] = {
- val writtenRows = writeWithV2(write.toBatch)
- postDriverMetrics()
+ val writtenRows = try {
+ writeWithV2(write.toBatch)
+ } finally {
+ postDriverMetrics()
+ }
refreshCache()
writtenRows
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala
index 715112e352963..b46223db6abb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala
@@ -130,7 +130,10 @@ class JDBCTableCatalog extends TableCatalog
}
override def loadTable(ident: Identifier): Table = {
- checkNamespace(ident.namespace())
+ if (!tableExists(ident)) {
+ throw QueryCompilationErrors.noSuchTableError(ident)
+ }
+
val optionsWithTableName = new JDBCOptions(
options.parameters + (JDBCOptions.JDBC_TABLE_NAME -> getTableName(ident)))
JdbcUtils.classifyException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala
index edea702587791..7c113c1cb03a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala
@@ -52,6 +52,25 @@ class PythonDataSourceV2 extends TableProvider {
dataSourceInPython
}
+ private var readInfo: PythonDataSourceReadInfo = _
+
+ def getOrCreateReadInfo(
+ shortName: String,
+ options: CaseInsensitiveStringMap,
+ outputSchema: StructType,
+ isStreaming: Boolean
+ ): PythonDataSourceReadInfo = {
+ if (readInfo == null) {
+ val creationResult = getOrCreateDataSourceInPython(shortName, options, Some(outputSchema))
+ readInfo = source.createReadInfoInPython(creationResult, outputSchema, isStreaming)
+ }
+ readInfo
+ }
+
+ def setReadInfo(readInfo: PythonDataSourceReadInfo): Unit = {
+ this.readInfo = readInfo
+ }
+
override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
getOrCreateDataSourceInPython(shortName, options, None).schema
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index b3ecfc8bb7f7e..65c71dd4eeb7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -90,10 +90,7 @@ class PythonMicroBatchStream(
}
private lazy val readInfo: PythonDataSourceReadInfo = {
- ds.source.createReadInfoInPython(
- ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
- outputSchema,
- isStreaming = true)
+ ds.getOrCreateReadInfo(shortName, options, outputSchema, isStreaming = true)
}
override def createReaderFactory(): PartitionReaderFactory = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index 8ebb91c01fc5c..52af33e7aa993 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -20,16 +20,18 @@ import org.apache.spark.JobArtifactSet
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.read.streaming.MicroBatchStream
+import org.apache.spark.sql.internal.connector.SupportsMetadata
+import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
-
class PythonScan(
- ds: PythonDataSourceV2,
- shortName: String,
- outputSchema: StructType,
- options: CaseInsensitiveStringMap) extends Scan {
-
+ ds: PythonDataSourceV2,
+ shortName: String,
+ outputSchema: StructType,
+ options: CaseInsensitiveStringMap,
+ supportedFilters: Array[Filter]
+) extends Scan with SupportsMetadata {
override def toBatch: Batch = new PythonBatch(ds, shortName, outputSchema, options)
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream =
@@ -44,6 +46,13 @@ class PythonScan(
override def columnarSupportMode(): Scan.ColumnarSupportMode =
Scan.ColumnarSupportMode.UNSUPPORTED
+
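+ // Shown in the plan description so that EXPLAIN reports the pushed filters and read schema.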
+ override def getMetaData(): Map[String, String] = {
+ Map(
+ "PushedFilters" -> supportedFilters.mkString("[", ", ", "]"),
+ "ReadSchema" -> outputSchema.simpleString
+ )
+ }
}
class PythonBatch(
@@ -54,10 +63,7 @@ class PythonBatch(
private val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
private lazy val infoInPython: PythonDataSourceReadInfo = {
- ds.source.createReadInfoInPython(
- ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
- outputSchema,
- isStreaming = false)
+ ds.getOrCreateReadInfo(shortName, options, outputSchema, isStreaming = false)
}
override def planInputPartitions(): Array[InputPartition] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala
index e30fc9f7978cb..3dabbcb8af05b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala
@@ -16,7 +16,9 @@
*/
package org.apache.spark.sql.execution.datasources.v2.python
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -25,6 +27,35 @@ class PythonScanBuilder(
ds: PythonDataSourceV2,
shortName: String,
outputSchema: StructType,
- options: CaseInsensitiveStringMap) extends ScanBuilder {
- override def build(): Scan = new PythonScan(ds, shortName, outputSchema, options)
+ options: CaseInsensitiveStringMap)
+ extends ScanBuilder
+ with SupportsPushDownFilters {
+ private var supportedFilters: Array[Filter] = Array.empty
+
+ override def build(): Scan =
+ new PythonScan(ds, shortName, outputSchema, options, supportedFilters)
+
+ // Optionally called by DSv2 once to push down filters before the scan is built.
+ override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ if (!SQLConf.get.pythonFilterPushDown) {
+ return filters
+ }
+
+ val dataSource = ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema))
+ ds.source.pushdownFiltersInPython(dataSource, outputSchema, filters) match {
+ case None => filters // No filters are supported.
+ case Some(result) =>
+ // Filter pushdown also returns partitions and the read function.
+ // This helps reduce the number of Python worker calls.
+ ds.setReadInfo(result.readInfo)
+
+ // Partition the filters into supported and unsupported ones.
+ val isPushed = result.isFilterPushed.zip(filters)
+ supportedFilters = isPushed.collect { case (true, filter) => filter }.toArray
+ val unsupported = isPushed.collect { case (false, filter) => filter }.toArray
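+ // Per the SupportsPushDownFilters contract, return the filters Spark must still evaluate.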
+ unsupported
+ }
+ }
+
+ override def pushedFilters(): Array[Filter] = supportedFilters
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index b3fd8479bda0d..14aeba92dafe1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -22,10 +22,16 @@ import java.io.{DataInputStream, DataOutputStream}
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
+import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude}
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.fasterxml.jackson.module.scala.DefaultScalaModule
import net.razorvine.pickle.Pickler
import org.apache.spark.api.python._
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.PythonUDF
+import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
@@ -34,8 +40,11 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInBatchEvaluatorFactory, PythonPlannerRunner, PythonSQLMetrics}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BinaryType, DataType, StructType}
+import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.unsafe.types.VariantVal
import org.apache.spark.util.ArrayImplicits._
/**
@@ -61,6 +70,26 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)).runInPython()
}
+ /**
+ * (Driver-side) Run Python process to push down filters, get the updated
+ * data source instance and the filter pushdown result.
+ */
+ def pushdownFiltersInPython(
+ pythonResult: PythonDataSourceCreationResult,
+ outputSchema: StructType,
+ filters: Array[Filter]): Option[PythonFilterPushdownResult] = {
+ val runner = new UserDefinedPythonDataSourceFilterPushdownRunner(
+ createPythonFunction(pythonResult.dataSource),
+ outputSchema,
+ filters
+ )
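+ // Skip the Python round trip entirely when none of the filters could be serialized.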
+ if (runner.isAnyFilterSupported) {
+ Some(runner.runInPython())
+ } else {
+ None
+ }
+ }
+
/**
* (Driver-side) Run Python process, and get the partition read functions, and
* partition information.
@@ -134,6 +163,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
toAttributes(outputSchema),
Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)),
inputSchema,
+ outputSchema,
conf.arrowMaxRecordsPerBatch,
pythonEvalType,
conf.sessionLocalTimeZone,
@@ -300,49 +330,170 @@ private class UserDefinedPythonDataSourceRunner(
}
}
-case class PythonDataSourceReadInfo(
- func: Array[Byte],
- partitions: Seq[Array[Byte]])
+/**
+ * @param isFilterPushed A sequence of booleans indicating, for each filter, whether it was pushed down.
+ */
+case class PythonFilterPushdownResult(
+ readInfo: PythonDataSourceReadInfo,
+ isFilterPushed: collection.Seq[Boolean]
+)
/**
- * Send information to a Python process to plan a Python data source read.
+ * Push down filters to a Python data source.
*
- * @param func a Python data source instance
- * @param inputSchema input schema to the data source read from its child plan
- * @param outputSchema output schema of the Python data source
+ * @param dataSource
+ * a Python data source instance
+ * @param schema
+ * output schema of the Python data source
+ * @param filters
+ * all filters to be pushed down
*/
-private class UserDefinedPythonDataSourceReadRunner(
- func: PythonFunction,
- inputSchema: StructType,
- outputSchema: StructType,
- isStreaming: Boolean) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
+private class UserDefinedPythonDataSourceFilterPushdownRunner(
+ dataSource: PythonFunction,
+ schema: StructType,
+ filters: collection.Seq[Filter])
+ extends PythonPlannerRunner[PythonFilterPushdownResult](dataSource) {
+
+ private case class SerializedFilter(
+ name: String,
+ columnPath: collection.Seq[String],
+ @JsonInclude(JsonInclude.Include.NON_ABSENT)
+ value: Option[VariantVal],
+ @JsonInclude(JsonInclude.Include.NON_DEFAULT)
+ isNegated: Boolean,
+ @JsonIgnore
+ index: Int
+ )
+
+ private val mapper = new ObjectMapper().registerModules(DefaultScalaModule)
+
+ private def getField(attribute: String): (Seq[String], StructField) = {
+ val columnPath = CatalystSqlParser.parseMultipartIdentifier(attribute)
+ val (_, field) = schema
+ .findNestedField(columnPath, includeCollections = true)
+ .getOrElse(
+ throw QueryCompilationErrors.pythonDataSourceError(
+ action = "plan",
+ tpe = "filter",
+ msg = s"Cannot find field $columnPath in schema"
+ )
+ )
+ (columnPath, field)
+ }
- // See the logic in `pyspark.sql.worker.plan_data_source_read.py`.
- override val workerModule = "pyspark.sql.worker.plan_data_source_read"
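+ // Translate each supported Filter into a JSON-serializable form; filters that cannot be
+ // serialized are skipped and treated as not pushed.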
+ private val serializedFilters = filters.zipWithIndex.flatMap {
+ case (filter, i) =>
+ // Unwrap Not filter
+ val (childFilter, isNegated) = filter match {
+ case sources.Not(f) => (f, true)
+ case _ => (filter, false)
+ }
+
+ def construct(
+ name: String,
+ attribute: String,
+ value: Option[Any],
+ mapDataType: DataType => DataType = identity): Option[SerializedFilter] = {
+ val (columnPath, field) = getField(attribute)
+ val dataType = mapDataType(field.dataType)
+ val variant = for (v <- value) yield {
+ val catalystValue = CatalystTypeConverters.convertToCatalyst(v)
+ try {
+ VariantExpressionEvalUtils.castToVariant(catalystValue, dataType)
+ } catch {
+ case _: MatchError =>
+ // The filter is unsupported if the value cannot be cast to a variant.
+ return None
+ }
+ }
+ Some(SerializedFilter(name, columnPath, variant, isNegated, i))
+ }
+
+ childFilter match {
+ case sources.EqualTo(attribute, value) =>
+ construct("EqualTo", attribute, Some(value))
+ case sources.EqualNullSafe(attribute, value) =>
+ construct("EqualNullSafe", attribute, Some(value))
+ case sources.GreaterThan(attribute, value) =>
+ construct("GreaterThan", attribute, Some(value))
+ case sources.GreaterThanOrEqual(attribute, value) =>
+ construct("GreaterThanOrEqual", attribute, Some(value))
+ case sources.LessThan(attribute, value) =>
+ construct("LessThan", attribute, Some(value))
+ case sources.LessThanOrEqual(attribute, value) =>
+ construct("LessThanOrEqual", attribute, Some(value))
+ case sources.In(attribute, value) =>
+ construct("In", attribute, Some(value), ArrayType(_))
+ case sources.IsNull(attribute) =>
+ construct("IsNull", attribute, None)
+ case sources.IsNotNull(attribute) =>
+ construct("IsNotNull", attribute, None)
+ case sources.StringStartsWith(attribute, value) =>
+ construct("StringStartsWith", attribute, Some(value))
+ case sources.StringEndsWith(attribute, value) =>
+ construct("StringEndsWith", attribute, Some(value))
+ case sources.StringContains(attribute, value) =>
+ construct("StringContains", attribute, Some(value))
+ // Collation-aware filters are currently not supported.
+ // `And` and `Or` filters are currently not supported.
+ case _ =>
+ None
+ }
+ }
+
+ // See the logic in `pyspark.sql.worker.data_source_pushdown_filters.py`.
+ override val workerModule = "pyspark.sql.worker.data_source_pushdown_filters"
+
+ def isAnyFilterSupported: Boolean = serializedFilters.nonEmpty
override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
// Send Python data source
- PythonWorkerUtils.writePythonFunction(func, dataOut)
-
- // Send input schema
- PythonWorkerUtils.writeUTF(inputSchema.json, dataOut)
+ PythonWorkerUtils.writePythonFunction(dataSource, dataOut)
// Send output schema
- PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+ PythonWorkerUtils.writeUTF(schema.json, dataOut)
+
+ // Send the filters
+ PythonWorkerUtils.writeUTF(mapper.writeValueAsString(serializedFilters), dataOut)
// Send configurations
dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
+ }
- dataOut.writeBoolean(isStreaming)
+ override protected def receiveFromPython(dataIn: DataInputStream): PythonFilterPushdownResult = {
+ // Receive the read function and the partitions. Also check for exceptions.
+ val readInfo = PythonDataSourceReadInfo.receive(dataIn)
+
+ // Receive the pushed filters as a list of indices.
+ val numFiltersPushed = dataIn.readInt()
+ val isFilterPushed = ArrayBuffer.fill(filters.length)(false)
+ for (_ <- 0 until numFiltersPushed) {
+ val i = dataIn.readInt()
+ isFilterPushed(serializedFilters(i).index) = true
+ }
+
+ PythonFilterPushdownResult(
+ readInfo = readInfo,
+ isFilterPushed = isFilterPushed
+ )
}
+}
- override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReadInfo = {
+case class PythonDataSourceReadInfo(
+ func: Array[Byte],
+ partitions: Seq[Array[Byte]])
+
+object PythonDataSourceReadInfo {
+ def receive(dataIn: DataInputStream): PythonDataSourceReadInfo = {
// Receive the pickled reader or an exception raised in the Python worker.
val length = dataIn.readInt()
if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
val msg = PythonWorkerUtils.readUTF(dataIn)
throw QueryCompilationErrors.pythonDataSourceError(
- action = "initialize", tpe = "reader", msg = msg)
+ action = "initialize",
+ tpe = "reader",
+ msg = msg
+ )
}
// Receive the pickled 'read' function.
@@ -354,16 +505,55 @@ private class UserDefinedPythonDataSourceReadRunner(
if (numPartitions == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
val msg = PythonWorkerUtils.readUTF(dataIn)
throw QueryCompilationErrors.pythonDataSourceError(
- action = "generate", tpe = "read partitions", msg = msg)
+ action = "generate",
+ tpe = "read partitions",
+ msg = msg
+ )
}
for (_ <- 0 until numPartitions) {
val pickledPartition: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
pickledPartitions.append(pickledPartition)
}
- PythonDataSourceReadInfo(
- func = pickledFunction,
- partitions = pickledPartitions.toSeq)
+ PythonDataSourceReadInfo(func = pickledFunction, partitions = pickledPartitions.toSeq)
+ }
+}
+
+/**
+ * Send information to a Python process to plan a Python data source read.
+ *
+ * @param func a Python data source instance
+ * @param inputSchema input schema to the data source read from its child plan
+ * @param outputSchema output schema of the Python data source
+ */
+private class UserDefinedPythonDataSourceReadRunner(
+ func: PythonFunction,
+ inputSchema: StructType,
+ outputSchema: StructType,
+ isStreaming: Boolean) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
+
+ // See the logic in `pyspark.sql.worker.plan_data_source_read.py`.
+ override val workerModule = "pyspark.sql.worker.plan_data_source_read"
+
+ override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+ // Send Python data source
+ PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+ // Send input schema
+ PythonWorkerUtils.writeUTF(inputSchema.json, dataOut)
+
+ // Send output schema
+ PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+ // Send configurations
+ dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
+ dataOut.writeBoolean(SQLConf.get.pythonFilterPushDown)
+
+ dataOut.writeBoolean(isStreaming)
+ }
+
+ override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReadInfo = {
+ PythonDataSourceReadInfo.receive(dataIn)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 649c23b390e01..28de012402220 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -55,7 +55,11 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
// Seq of operator names that use state schema v3 and TWS related options.
// This Seq is used in checks before reading state schema files.
- private val twsShortNameSeq = Seq("transformWithStateExec", "transformWithStateInPandasExec")
+ private val twsShortNameSeq = Seq(
+ "transformWithStateExec",
+ "transformWithStateInPandasExec",
+ "transformWithStateInPySparkExec"
+ )
override def shortName(): String = "statestore"
@@ -88,6 +92,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
try {
+ // SPARK-51779 TODO: Support stream-stream joins with virtual column families
val (keySchema, valueSchema) = sourceOptions.joinSide match {
case JoinSideValues.left =>
StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
index b9adb379e38c1..e1d61de77380f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.Jo
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{JoinSide, LeftSide, RightSide}
-import org.apache.spark.sql.execution.streaming.state.{StateStoreConf, SymmetricHashJoinStateManager}
+import org.apache.spark.sql.execution.streaming.state.{JoinStateManagerStoreGenerator, StateStoreConf, SymmetricHashJoinStateManager}
import org.apache.spark.sql.types.{BooleanType, StructType}
import org.apache.spark.util.SerializableConfiguration
@@ -111,7 +111,7 @@ class StreamStreamJoinStatePartitionReader(
partition.sourceOptions.stateCheckpointLocation.toString,
partition.queryId, partition.sourceOptions.operatorId,
partition.sourceOptions.batchId + 1, -1, None)
- joinStateManager = new SymmetricHashJoinStateManager(
+ joinStateManager = SymmetricHashJoinStateManager(
joinSide,
inputAttributes,
joinKeys = DataTypeUtils.toAttributes(keySchema),
@@ -125,7 +125,8 @@ class StreamStreamJoinStatePartitionReader(
skippedNullValueCount = None,
useStateStoreCoordinator = false,
snapshotStartVersion =
- partition.sourceOptions.fromSnapshotOptions.map(_.snapshotStartBatchId + 1)
+ partition.sourceOptions.fromSnapshotOptions.map(_.snapshotStartBatchId + 1),
+ joinStoreGenerator = new JoinStateManagerStoreGenerator()
)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
index fafde89001aa2..23bca35725397 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.classic.ClassicConversions.castToImpl
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType, VariantType}
import org.apache.spark.util.Utils
/**
@@ -67,10 +67,14 @@ abstract class XmlDataSource extends Serializable with Logging {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: XmlOptions): Option[StructType] = {
- if (inputPaths.nonEmpty) {
- Some(infer(sparkSession, inputPaths, parsedOptions))
- } else {
- None
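+ // In single-variant-column mode the schema is fixed to one variant column, so skip inference.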
+ parsedOptions.singleVariantColumn match {
+ case Some(columnName) => Some(StructType(Array(StructField(columnName, VariantType))))
+ case None =>
+ if (inputPaths.nonEmpty) {
+ Some(infer(sparkSession, inputPaths, parsedOptions))
+ } else {
+ None
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
index a7661d8dbf8e9..e5004c499a070 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.SerializableConfiguration
/**
* Provides access to XML data from pure SQL statements.
*/
-class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister {
+case class XmlFileFormat() extends TextBasedFileFormat with DataSourceRegister {
override def shortName(): String = "xml"
@@ -132,13 +132,10 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def toString: String = "XML"
- override def hashCode(): Int = getClass.hashCode()
-
- override def equals(other: Any): Boolean = other.isInstanceOf[XmlFileFormat]
-
override def supportDataType(dataType: DataType): Boolean = dataType match {
- case _: VariantType => false
+ case _: VariantType => true
+ case _: TimeType => false
case _: AtomicType => true
case st: StructType => st.forall { f => supportDataType(f.dataType) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index 5f5a9e188532e..059729d86bfaf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -55,8 +55,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp
plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
case DynamicPruningSubquery(
value, buildPlan, buildKeys, broadcastKeyIndices, onlyInBroadcast, exprId, _) =>
- val sparkPlan = QueryExecution.createSparkPlan(
- sparkSession, sparkSession.sessionState.planner, buildPlan)
+ val sparkPlan = QueryExecution.createSparkPlan(sparkSession.sessionState.planner, buildPlan)
// Using `sparkPlan` is a little hacky as it is based on the assumption that this rule is
// the first to be applied (apart from `InsertAdaptiveSparkPlan`).
val canReuseExchange = conf.exchangeReuseEnabled && buildKeys.nonEmpty &&
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 2565a14cef90b..c70ee637a2489 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -125,7 +125,6 @@ trait BroadcastExchangeLike extends Exchange {
case class BroadcastExchangeExec(
mode: BroadcastMode,
child: SparkPlan) extends BroadcastExchangeLike {
- import BroadcastExchangeExec._
override lazy val metrics = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
@@ -203,9 +202,10 @@ case class BroadcastExchangeExec(
}
longMetric("dataSize") += dataSize
- if (dataSize >= MAX_BROADCAST_TABLE_BYTES) {
+ val maxBroadcastTableSizeInBytes = conf.maxBroadcastTableSizeInBytes
+ if (dataSize >= maxBroadcastTableSizeInBytes) {
throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(
- MAX_BROADCAST_TABLE_BYTES, dataSize)
+ maxBroadcastTableSizeInBytes, dataSize)
}
val beforeBroadcast = System.nanoTime()
@@ -268,8 +268,6 @@ case class BroadcastExchangeExec(
}
object BroadcastExchangeExec {
- val MAX_BROADCAST_TABLE_BYTES = 8L << 30
-
private[execution] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange",
SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index ce7d48babc91e..a1abb64e262df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -25,11 +25,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator}
import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.types.{BooleanType, IntegralType, LongType}
+import org.apache.spark.sql.types.{BooleanType, DataType, IntegralType, LongType}
/**
* @param relationTerm variable name for HashedRelation
@@ -111,7 +110,7 @@ trait HashJoin extends JoinCodegenSupport {
require(leftKeys.length == rightKeys.length &&
leftKeys.map(_.dataType)
.zip(rightKeys.map(_.dataType))
- .forall(types => DataTypeUtils.sameType(types._1, types._2)),
+ .forall(types => DataType.equalsStructurally(types._1, types._2, ignoreNullability = true)),
"Join keys from two sides should have same length and types")
buildSide match {
case BuildLeft => (leftKeys, rightKeys)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 6bd49e75af241..ca7836992aacb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -36,6 +36,8 @@ import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.python.BatchIterator
import org.apache.spark.sql.execution.r.ArrowRRunner
import org.apache.spark.sql.execution.streaming.GroupStateImpl
import org.apache.spark.sql.internal.SQLConf
@@ -218,13 +220,17 @@ case class MapPartitionsInRWithArrowExec(
child: SparkPlan) extends UnaryExecNode {
override def producedAttributes: AttributeSet = AttributeSet(output)
+ private val batchSize = conf.arrowMaxRecordsPerBatch
+
override def outputPartitioning: Partitioning = child.outputPartitioning
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { inputIter =>
val outputTypes = schema.map(_.dataType)
- val batchIter = Iterator(inputIter)
+ // DO NOT use iter.grouped(). See BatchIterator.
+ val batchIter =
+ if (batchSize > 0) new BatchIterator(inputIter, batchSize) else Iterator(inputIter)
val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_DAPPLY)
@@ -247,8 +253,10 @@ case class MapPartitionsInRWithArrowExec(
val outputProject = UnsafeProjection.create(output, output)
columnarBatchIter.flatMap { batch =>
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
- assert(outputTypes == actualDataTypes, "Invalid schema from dapply(): " +
- s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
+ if (outputTypes != actualDataTypes) {
+ throw QueryExecutionErrors.arrowDataTypeMismatchError(
+ "dapply()", outputTypes, actualDataTypes)
+ }
batch.rowIterator.asScala
}.map(outputProject)
}
@@ -593,8 +601,10 @@ case class FlatMapGroupsInRWithArrowExec(
columnarBatchIter.flatMap { batch =>
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
- assert(outputTypes == actualDataTypes, "Invalid schema from gapply(): " +
- s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
+ if (outputTypes != actualDataTypes) {
+ throw QueryExecutionErrors.arrowDataTypeMismatchError(
+ "gapply()", outputTypes, actualDataTypes)
+ }
batch.rowIterator().asScala
}.map(outputProject)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 5bf84b22a792b..9ec454731e4a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -23,10 +23,11 @@ import org.apache.spark.{JobArtifactSet, TaskContext}
import org.apache.spark.api.python.ChainedPythonFunctions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructType, UserDefinedType}
/**
* Groups an iterator into batches.
@@ -106,7 +107,9 @@ class ArrowEvalPythonEvaluatorFactory(
schema: StructType,
context: TaskContext): Iterator[InternalRow] = {
- val outputTypes = output.drop(childOutput.length).map(_.dataType)
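+ // Arrow returns the underlying sqlType for UDT columns, so compare against that type.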
+ val outputTypes = output.drop(childOutput.length).map(_.dataType.transformRecursively {
+ case udt: UserDefinedType[_] => udt.sqlType
+ })
val batchIter = Iterator(iter)
@@ -125,8 +128,10 @@ class ArrowEvalPythonEvaluatorFactory(
columnarBatchIter.flatMap { batch =>
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
- assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
- s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
+ if (outputTypes != actualDataTypes) {
+ throw QueryExecutionErrors.arrowDataTypeMismatchError(
+ "pandas_udf()", outputTypes, actualDataTypes)
+ }
batch.rowIterator.asScala
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 9e210bf5241bb..d7106403a3880 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -22,6 +22,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.{JobArtifactSet, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.types.StructType
@@ -81,8 +82,10 @@ case class ArrowEvalPythonUDTFExec(
val actualDataTypes = (0 until flattenedBatch.numCols()).map(
i => flattenedBatch.column(i).dataType())
- assert(outputTypes == actualDataTypes, "Invalid schema from arrow-enabled Python UDTF: " +
- s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
+ if (outputTypes != actualDataTypes) {
+ throw QueryExecutionErrors.arrowDataTypeMismatchError(
+ "Python UDTF", outputTypes, actualDataTypes)
+ }
flattenedBatch.setNumRows(batch.numRows())
flattenedBatch.rowIterator().asScala
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index d7f73c648ac29..9956e4ce0f3b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -107,8 +107,12 @@ class ArrowPythonWithNamedArgumentRunner(
funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
pythonMetrics, jobArtifactUUID) {
- override protected def writeUDF(dataOut: DataOutputStream): Unit =
+ override protected def writeUDF(dataOut: DataOutputStream): Unit = {
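+ // For Arrow-optimized Python UDFs, send the schema JSON before the serialized UDFs.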
+ if (evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF) {
+ PythonWorkerUtils.writeUTF(schema.json, dataOut)
+ }
PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas, profiler)
+ }
}
object ArrowPythonRunner {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 99bcbfd9eb246..6a81eaffd20a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -173,7 +173,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
private def canEvaluateInPython(e: PythonUDF): Boolean = {
e.children match {
// single PythonUDF child could be chained and evaluated in Python
- case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
+ case Seq(u: PythonUDF) => correctEvalType(e) == correctEvalType(u) && canEvaluateInPython(u)
// Python UDF can't be evaluated directly in JVM
case children => !children.exists(hasScalarPythonUDF)
}
@@ -197,10 +197,10 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
&& firstVisitedScalarUDFEvalType.isEmpty =>
- firstVisitedScalarUDFEvalType = Some(udf.evalType)
+ firstVisitedScalarUDFEvalType = Some(correctEvalType(udf))
Seq(udf)
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
- && canChainUDF(udf.evalType) =>
+ && canChainUDF(correctEvalType(udf)) =>
Seq(udf)
case e => e.children.flatMap(collectEvaluableUDFs)
}
@@ -235,6 +235,19 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
}
}
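+ /**
+ * Returns the effective eval type: an Arrow-optimized UDF falls back to SQL_BATCHED_UDF when
+ * UDTs appear in its input or return type and the UDT fallback is enabled.
+ */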
+ private def correctEvalType(udf: PythonUDF): Int = {
+ if (udf.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF) {
+ if (conf.pythonUDFArrowFallbackOnUDT &&
+ (containsUDT(udf.dataType) || udf.children.exists(expr => containsUDT(expr.dataType)))) {
+ PythonEvalType.SQL_BATCHED_UDF
+ } else {
+ PythonEvalType.SQL_ARROW_BATCHED_UDF
+ }
+ } else {
+ udf.evalType
+ }
+ }
+
private def containsUDT(dataType: DataType): Boolean = dataType match {
case _: UserDefinedType[_] => true
case ArrayType(elementType, _) => containsUDT(elementType)
@@ -272,33 +285,25 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
AttributeReference(s"pythonUDF$i", u.dataType)()
}
- val evalTypes = validUdfs.map(_.evalType).toSet
+ val evalTypes = validUdfs.map(correctEvalType).toSet
if (evalTypes.size != 1) {
throw SparkException.internalError(
"Expected udfs have the same evalType but got different evalTypes: " +
evalTypes.mkString(","))
}
val evalType = evalTypes.head
-
- val hasUDTInput = validUdfs.exists(_.children.exists(expr => containsUDT(expr.dataType)))
- val hasUDTReturn = validUdfs.exists(udf => containsUDT(udf.dataType))
-
val evaluation = evalType match {
case PythonEvalType.SQL_BATCHED_UDF =>
- BatchEvalPython(validUdfs, resultAttrs, child)
- case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF =>
- ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
- case PythonEvalType.SQL_ARROW_BATCHED_UDF =>
-
- if (hasUDTInput || hasUDTReturn) {
+ if (validUdfs.exists(_.evalType != PythonEvalType.SQL_BATCHED_UDF)) {
// Use BatchEvalPython if UDT is detected
logWarning(log"Arrow optimization disabled due to " +
log"${MDC(REASON, "UDT input or return type")}. " +
log"Falling back to non-Arrow-optimized UDF execution.")
- BatchEvalPython(validUdfs, resultAttrs, child)
- } else {
- ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
}
+ BatchEvalPython(validUdfs, resultAttrs, child)
+ case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
+ | PythonEvalType.SQL_ARROW_BATCHED_UDF =>
+ ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
case _ =>
throw SparkException.internalError("Unexpected UDF evalType")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 3b7b2c56744a8..9e3e8610ed375 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -20,17 +20,20 @@ package org.apache.spark.sql.execution.python
import scala.jdk.CollectionConverters._
import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, TaskContext}
-import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
class MapInBatchEvaluatorFactory(
output: Seq[Attribute],
chainedFunc: Seq[(ChainedPythonFunctions, Long)],
- outputTypes: StructType,
+ inputSchema: StructType,
+ outputSchema: DataType,
batchSize: Int,
pythonEvalType: Int,
sessionLocalTimeZone: String,
@@ -63,18 +66,30 @@ class MapInBatchEvaluatorFactory(
chainedFunc,
pythonEvalType,
argOffsets,
- StructType(Array(StructField("struct", outputTypes))),
+ StructType(Array(StructField("struct", inputSchema))),
sessionLocalTimeZone,
largeVarTypes,
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
- None)
+ None) with BatchedPythonArrowInput
val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
val unsafeProj = UnsafeProjection.create(output, output)
columnarBatchIter.flatMap { batch =>
+ if (SQLConf.get.pysparkArrowValidateSchema) {
+ // Ensure the schema matches the expected schema, but allowing nullable fields in the
+ // output schema to become non-nullable in the actual schema.
+ val actualSchema = batch.column(0).dataType()
+ val isCompatible =
+ DataType.equalsIgnoreCompatibleNullability(from = actualSchema, to = outputSchema)
+ if (!isCompatible) {
+ throw QueryExecutionErrors.arrowDataTypeMismatchError(
+ PythonEvalType.toString(pythonEvalType), Seq(outputSchema), Seq(actualSchema))
+ }
+ }
+
// Scalar Iterator UDF returns a StructType column in ColumnarBatch, select
// the children here
val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 096e9d7d16420..c003d503c7caf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -56,6 +56,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
output,
chainedFunc,
child.schema,
+ pythonUDF.dataType,
conf.arrowMaxRecordsPerBatch,
pythonEvalType,
conf.sessionLocalTimeZone,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index adbfa341f3d5d..05fd571e0265b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -145,7 +145,10 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
self: BasePythonRunner[Iterator[InternalRow], _] =>
- private val arrowMaxRecordsPerBatch = SQLConf.get.arrowMaxRecordsPerBatch
+ private val arrowMaxRecordsPerBatch = {
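+ // A non-positive value means no limit, so normalize it to Int.MaxValue for the batching logic.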
+ val v = SQLConf.get.arrowMaxRecordsPerBatch
+ if (v > 0) v else Int.MaxValue
+ }
private val maxBytesPerBatch = SQLConf.get.arrowMaxBytesPerBatch
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index e7d4aa9f04607..9511b4c3f305c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, PythonWorker, SpecialLengths}
import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -43,6 +44,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT
+ protected def arrowMaxRecordsPerOutputBatch: Int = SQLConf.get.arrowMaxRecordsPerOutputBatch
+
protected def newReaderIterator(
stream: DataInputStream,
writer: Writer,
@@ -62,7 +65,7 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
private var reader: ArrowStreamReader = _
private var root: VectorSchemaRoot = _
private var schema: StructType = _
- private var vectors: Array[ColumnVector] = _
+ private var processor: ArrowOutputProcessor = _
context.addTaskCompletionListener[Unit] { _ =>
if (reader != null) {
@@ -84,17 +87,12 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
}
try {
if (reader != null && batchLoaded) {
- val bytesReadStart = reader.bytesRead()
- batchLoaded = reader.loadNextBatch()
+ batchLoaded = processor.loadBatch()
if (batchLoaded) {
- val batch = new ColumnarBatch(vectors)
- val rowCount = root.getRowCount
- batch.setNumRows(root.getRowCount)
- val bytesReadEnd = reader.bytesRead()
- pythonMetrics("pythonNumRowsReceived") += rowCount
- pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
+ val batch = processor.produceBatch()
deserializeColumnarBatch(batch, schema)
} else {
+ processor.close()
reader.close(false)
allocator.close()
// Reach end of stream. Call `read()` again to read control data.
@@ -106,9 +104,14 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
reader = new ArrowStreamReader(stream, allocator)
root = reader.getVectorSchemaRoot()
schema = ArrowUtils.fromArrowSchema(root.getSchema())
- vectors = root.getFieldVectors().asScala.map { vector =>
- new ArrowColumnVector(vector)
- }.toArray[ColumnVector]
+
+ if (arrowMaxRecordsPerOutputBatch > 0) {
+ processor = new SliceArrowOutputProcessorImpl(
+ reader, pythonMetrics, arrowMaxRecordsPerOutputBatch)
+ } else {
+ processor = new ArrowOutputProcessorImpl(reader, pythonMetrics)
+ }
+
read()
case SpecialLengths.TIMING_DATA =>
handleTimingData()
@@ -133,3 +136,114 @@ private[python] trait BasicPythonArrowOutput extends PythonArrowOutput[ColumnarB
batch: ColumnarBatch,
schema: StructType): ColumnarBatch = batch
}
+
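+/**
+ * Abstracts how Arrow batches read from the Python worker are turned into ColumnarBatches;
+ * implementations may re-slice the output into smaller batches.
+ */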
+trait ArrowOutputProcessor {
+ def loadBatch(): Boolean
+ protected def getRoot: VectorSchemaRoot
+ protected def getVectors(root: VectorSchemaRoot): Array[ColumnVector]
+ def produceBatch(): ColumnarBatch
+ def close(): Unit
+}
+
+class ArrowOutputProcessorImpl(reader: ArrowStreamReader, pythonMetrics: Map[String, SQLMetric])
+ extends ArrowOutputProcessor {
+ protected val root = reader.getVectorSchemaRoot()
+ protected val schema: StructType = ArrowUtils.fromArrowSchema(root.getSchema())
+ private val vectors: Array[ColumnVector] = root.getFieldVectors().asScala.map { vector =>
+ new ArrowColumnVector(vector)
+ }.toArray[ColumnVector]
+
+ protected var rowCount = -1
+
+ override def loadBatch(): Boolean = {
+ val bytesReadStart = reader.bytesRead()
+ val batchLoaded = reader.loadNextBatch()
+ if (batchLoaded) {
+ rowCount = root.getRowCount
+ val bytesReadEnd = reader.bytesRead()
+ pythonMetrics("pythonNumRowsReceived") += rowCount
+ pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
+ }
+ batchLoaded
+ }
+
+ protected override def getRoot: VectorSchemaRoot = root
+ protected override def getVectors(root: VectorSchemaRoot): Array[ColumnVector] = vectors
+ override def produceBatch(): ColumnarBatch = {
+ val batchRoot = getRoot
+ val vectors = getVectors(batchRoot)
+ val batch = new ColumnarBatch(vectors)
+ batch.setNumRows(batchRoot.getRowCount)
+ batch
+ }
+ override def close(): Unit = {
+ vectors.foreach(_.close())
+ root.close()
+ }
+}
+
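+/**
+ * An ArrowOutputProcessor that slices each Arrow batch received from Python into chunks of
+ * at most `arrowMaxRecordsPerOutputBatch` rows before exposing them as ColumnarBatches.
+ */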
+class SliceArrowOutputProcessorImpl(
+ reader: ArrowStreamReader,
+ pythonMetrics: Map[String, SQLMetric],
+ arrowMaxRecordsPerOutputBatch: Int)
+ extends ArrowOutputProcessorImpl(reader, pythonMetrics) {
+
+ private var currentRowIdx = -1
+ private var prevRoot: VectorSchemaRoot = null
+ private var prevVectors: Array[ColumnVector] = _
+
+ override def produceBatch(): ColumnarBatch = {
+ val batchRoot = getRoot
+
+ if (batchRoot != prevRoot) {
+ if (prevRoot != null) {
+ prevVectors.foreach(_.close())
+ prevRoot.close()
+ }
+ prevRoot = batchRoot
+ }
+
+ val vectors = getVectors(batchRoot)
+ prevVectors = vectors
+
+ val batch = new ColumnarBatch(vectors)
+ batch.setNumRows(batchRoot.getRowCount)
+ batch
+ }
+
+ override def loadBatch(): Boolean = {
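+ // Rows remain in the currently loaded Arrow batch; keep slicing it before loading the next one.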
+ if (rowCount > 0 && currentRowIdx < rowCount) {
+ true
+ } else {
+ val loaded = super.loadBatch()
+ currentRowIdx = 0
+ loaded
+ }
+ }
+
+ protected override def getRoot: VectorSchemaRoot = {
+ val remainingRows = rowCount - currentRowIdx
+ val rootSlice = if (remainingRows > arrowMaxRecordsPerOutputBatch) {
+ root.slice(currentRowIdx, arrowMaxRecordsPerOutputBatch)
+ } else {
+ root
+ }
+
+ currentRowIdx = currentRowIdx + rootSlice.getRowCount
+
+ rootSlice
+ }
+
+ protected override def getVectors(root: VectorSchemaRoot): Array[ColumnVector] = {
+ root.getFieldVectors.asScala.map { vector =>
+ new ArrowColumnVector(vector)
+ }.toArray[ColumnVector]
+ }
+
+ override def close(): Unit = {
+ if (prevRoot != null) {
+ prevVectors.foreach(_.close())
+ prevRoot.close()
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
index 89273b7bc80f0..3979220618baa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
@@ -19,6 +19,7 @@
package org.apache.spark.sql.execution.python.streaming
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
+import java.nio.channels.Channels
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
@@ -99,7 +100,7 @@ class PythonStreamingSourceRunner(
pythonWorkerFactory = Some(workerFactory)
val stream = new BufferedOutputStream(
- pythonWorker.get.channel.socket().getOutputStream, bufferSize)
+ Channels.newOutputStream(pythonWorker.get.channel), bufferSize)
dataOut = new DataOutputStream(stream)
PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
@@ -118,7 +119,7 @@ class PythonStreamingSourceRunner(
dataOut.flush()
dataIn = new DataInputStream(
- new BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, bufferSize))
+ new BufferedInputStream(Channels.newInputStream(pythonWorker.get.channel), bufferSize))
val initStatus = dataIn.readInt()
if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkDeserializer.scala
similarity index 72%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkDeserializer.scala
index 1a8ffb35c0533..25f84f0be4c21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkDeserializer.scala
@@ -26,18 +26,19 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
/**
* A helper class to deserialize state Arrow batches from the state socket in
- * TransformWithStateInPandas.
+ * TransformWithStateInPySpark.
*/
-class TransformWithStateInPandasDeserializer(deserializer: ExpressionEncoder.Deserializer[Row])
+class TransformWithStateInPySparkDeserializer(deserializer: ExpressionEncoder.Deserializer[Row])
extends Logging {
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
- s"stdin reader for transformWithStateInPandas state socket", 0, Long.MaxValue)
+ s"stdin reader for transformWithStateInPySpark state socket", 0, Long.MaxValue)
/**
* Read Arrow batches from the given stream and deserialize them into rows.
@@ -57,4 +58,24 @@ class TransformWithStateInPandasDeserializer(deserializer: ExpressionEncoder.Des
reader.close(false)
rows.toSeq
}
+
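+ /**
+ * Reads length-prefixed serialized rows for a list state from the given stream until a
+ * negative length marks the end of the list.
+ */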
+ def readListElements(stream: DataInputStream, listStateInfo: ListStateInfo): Seq[Row] = {
+ val rows = new scala.collection.mutable.ArrayBuffer[Row]
+
+ var endOfLoop = false
+ while (!endOfLoop) {
+ val size = stream.readInt()
+ if (size < 0) {
+ endOfLoop = true
+ } else {
+ val bytes = new Array[Byte](size)
+ stream.read(bytes, 0, size)
+ val newRow = PythonSQLUtils.toJVMRow(bytes, listStateInfo.schema,
+ listStateInfo.deserializer)
+ rows.append(newRow)
+ }
+ }
+
+ rows.toSeq
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
similarity index 90%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
index e77035e31ccb6..d6e2a6f844ea0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, PythonUDF, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.ProcessingTime
+import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
@@ -46,7 +47,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti
/**
* Physical operator for executing
- * [[org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPandas]]
+ * [[org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark]]
*
* @param functionExpr function called on each group
* @param groupingAttributes used to group the data
@@ -57,6 +58,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
* @param eventTimeWatermarkForEviction event time watermark for state eviction
+ * @param userFacingDataType the user facing data type of the function (both param and return type)
* @param child the physical plan for the underlying data
* @param isStreaming defines whether the query is streaming or batch
* @param hasInitialState defines whether the query has initial state
@@ -64,7 +66,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti
* @param initialStateGroupingAttrs grouping attributes for initial state
* @param initialStateSchema schema for initial state
*/
-case class TransformWithStateInPandasExec(
+case class TransformWithStateInPySparkExec(
functionExpr: Expression,
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
@@ -74,6 +76,7 @@ case class TransformWithStateInPandasExec(
batchTimestampMs: Option[Long],
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
+ userFacingDataType: TransformWithStateInPySpark.UserFacingDataType.Value,
child: SparkPlan,
isStreaming: Boolean = true,
hasInitialState: Boolean,
@@ -85,7 +88,15 @@ case class TransformWithStateInPandasExec(
with WatermarkSupport
with TransformWithStateMetadataUtils {
- override def shortName: String = "transformWithStateInPandasExec"
+ // NOTE: This is needed to stay compatible with the existing release of transformWithStateInPandas.
+ override def shortName: String = if (
+ userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS
+ ) {
+ "transformWithStateInPandasExec"
+ } else {
+ "transformWithStateInPySparkExec"
+ }
+
private val pythonUDF = functionExpr.asInstanceOf[PythonUDF]
private val pythonFunction = pythonUDF.func
private val chainedFunc =
@@ -116,8 +127,7 @@ case class TransformWithStateInPandasExec(
override def operatorStateMetadataVersion: Int = 2
override def getColFamilySchemas(
- shouldBeNullable: Boolean
- ): Map[String, StateStoreColFamilySchema] = {
+ shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
// For Python, the user can explicitly set nullability on schema, so
// we need to throw an error if the schema is nullable
driverProcessorHandle.getColumnFamilySchemas(
@@ -172,7 +182,7 @@ case class TransformWithStateInPandasExec(
batchId: Long,
stateSchemaVersion: Int): List[StateSchemaValidationResult] = {
// Start a python runner on driver, and execute pre-init UDF on the runner
- val runner = new TransformWithStateInPandasPythonPreInitRunner(
+ val runner = new TransformWithStateInPySparkPythonPreInitRunner(
pythonFunction,
"pyspark.sql.streaming.transform_with_state_driver_worker",
sessionLocalTimeZone,
@@ -186,12 +196,14 @@ case class TransformWithStateInPandasExec(
runner.process()
} catch {
case e: Throwable =>
- throw new SparkException("TransformWithStateInPandas driver worker " +
+ throw new SparkException("TransformWithStateInPySpark driver worker " +
"exited unexpectedly (crashed)", e)
}
runner.stop()
+ val info = getStateInfo
+ val stateSchemaDir = stateSchemaDirPath()
- validateAndWriteStateSchema(hadoopConf, batchId, stateSchemaVersion, getStateInfo,
+ validateAndWriteStateSchema(hadoopConf, batchId, stateSchemaVersion, info, stateSchemaDir,
session, operatorStateMetadataVersion, stateStoreEncodingFormat =
conf.stateStoreEncodingFormat)
}
@@ -330,7 +342,7 @@ case class TransformWithStateInPandasExec(
private def initNewStateStoreAndProcessData(
partitionId: Int,
hadoopConfBroadcast: Broadcast[SerializableConfiguration])
- (f: StateStore => Iterator[InternalRow]): Iterator[InternalRow] = {
+ (f: StateStore => Iterator[InternalRow]): Iterator[InternalRow] = {
val providerId = {
val tempDirPath = Utils.createTempDir().getAbsolutePath
@@ -380,7 +392,7 @@ case class TransformWithStateInPandasExec(
// If timeout is based on event time, then filter late data based on watermark
val filteredIter = watermarkPredicateForDataForLateEvents match {
- case Some(predicate) =>
+ case Some(predicate) if timeMode == TimeMode.EventTime() =>
applyRemovingRowsOlderThanWatermark(dataIterator, predicate)
case _ =>
dataIterator
@@ -391,10 +403,18 @@ case class TransformWithStateInPandasExec(
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
+ val evalType = {
+ if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) {
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
+ } else {
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF
+ }
+ }
+
val outputIterator = if (!hasInitialState) {
- val runner = new TransformWithStateInPandasPythonRunner(
+ val runner = new TransformWithStateInPySparkPythonRunner(
chainedFunc,
- PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
+ evalType,
Array(argOffsets),
DataTypeUtils.fromAttributes(dedupAttributes),
processorHandle,
@@ -418,9 +438,17 @@ case class TransformWithStateInPandasExec(
val groupedData: Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] =
new CoGroupedIterator(data, initData, groupingAttributes)
- val runner = new TransformWithStateInPandasPythonInitialStateRunner(
+ val evalType = {
+ if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) {
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
+ } else {
+ PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF
+ }
+ }
+
+ val runner = new TransformWithStateInPySparkPythonInitialStateRunner(
chainedFunc,
- PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
+ evalType,
Array(argOffsets ++ initArgOffsets),
DataTypeUtils.fromAttributes(dedupAttributes),
DataTypeUtils.fromAttributes(initDedupAttributes),
@@ -457,7 +485,7 @@ case class TransformWithStateInPandasExec(
}
override protected def withNewChildrenInternal(
- newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateInPandasExec =
+ newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateInPySparkExec =
if (hasInitialState) {
copy(child = newLeft, initialState = newRight)
} else {
@@ -470,15 +498,16 @@ case class TransformWithStateInPandasExec(
}
// scalastyle:off argcount
-object TransformWithStateInPandasExec {
+object TransformWithStateInPySparkExec {
- // Plan logical transformWithStateInPandas for batch queries
+ // Plan logical transformWithStateInPySpark for batch queries
def generateSparkPlanForBatchQueries(
functionExpr: Expression,
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
outputMode: OutputMode,
timeMode: TimeMode,
+ userFacingDataType: TransformWithStateInPySpark.UserFacingDataType.Value,
child: SparkPlan,
hasInitialState: Boolean = false,
initialState: SparkPlan,
@@ -494,7 +523,7 @@ object TransformWithStateInPandasExec {
stateStoreCkptIds = None
)
- new TransformWithStateInPandasExec(
+ new TransformWithStateInPySparkExec(
functionExpr,
groupingAttributes,
output,
@@ -504,6 +533,7 @@ object TransformWithStateInPandasExec {
Some(System.currentTimeMillis),
None,
None,
+ userFacingDataType,
child,
isStreaming = false,
hasInitialState,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
similarity index 75%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 9b2a2518a7b2f..dffdaca1b835e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -17,21 +17,24 @@
package org.apache.spark.sql.execution.python.streaming
-import java.io.{DataInputStream, DataOutputStream}
-import java.net.ServerSocket
+import java.io.{DataInputStream, DataOutputStream, File}
+import java.net.{InetAddress, InetSocketAddress, StandardProtocolFamily, UnixDomainSocketAddress}
+import java.nio.channels.ServerSocketChannel
+import java.util.UUID
import scala.concurrent.ExecutionContext
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonFunction, PythonRDD, PythonWorkerUtils, StreamingPythonRunner}
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR, PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.{BasicPythonArrowOutput, PythonArrowInput, PythonUDFRunner}
-import org.apache.spark.sql.execution.python.streaming.TransformWithStateInPandasPythonRunner.{GroupedInType, InType}
+import org.apache.spark.sql.execution.python.streaming.TransformWithStateInPySparkPythonRunner.{GroupedInType, InType}
import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulProcessorHandleImpl}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
@@ -39,10 +42,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ThreadUtils
/**
- * Python runner with no initial state in TransformWithStateInPandas.
+ * Python runner with no initial state in TransformWithStateInPySpark.
* Write input data as one single InternalRow in each row in arrow batch.
*/
-class TransformWithStateInPandasPythonRunner(
+class TransformWithStateInPySparkPythonRunner(
funcs: Seq[(ChainedPythonFunctions, Long)],
evalType: Int,
argOffsets: Array[Array[Int]],
@@ -55,7 +58,7 @@ class TransformWithStateInPandasPythonRunner(
groupingKeySchema: StructType,
batchTimestampMs: Option[Long],
eventTimeWatermarkForEviction: Option[Long])
- extends TransformWithStateInPandasPythonBaseRunner[InType](
+ extends TransformWithStateInPySparkPythonBaseRunner[InType](
funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction)
@@ -93,10 +96,10 @@ class TransformWithStateInPandasPythonRunner(
}
/**
- * Python runner with initial state in TransformWithStateInPandas.
+ * Python runner with initial state in TransformWithStateInPySpark.
* Write input data as one InternalRow(inputRow, initialState) in each row in arrow batch.
*/
-class TransformWithStateInPandasPythonInitialStateRunner(
+class TransformWithStateInPySparkPythonInitialStateRunner(
funcs: Seq[(ChainedPythonFunctions, Long)],
evalType: Int,
argOffsets: Array[Array[Int]],
@@ -110,7 +113,7 @@ class TransformWithStateInPandasPythonInitialStateRunner(
groupingKeySchema: StructType,
batchTimestampMs: Option[Long],
eventTimeWatermarkForEviction: Option[Long])
- extends TransformWithStateInPandasPythonBaseRunner[GroupedInType](
+ extends TransformWithStateInPySparkPythonBaseRunner[GroupedInType](
funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction)
@@ -159,9 +162,9 @@ class TransformWithStateInPandasPythonInitialStateRunner(
}
/**
- * Base Python runner implementation for TransformWithStateInPandas.
+ * Base Python runner implementation for TransformWithStateInPySpark.
*/
-abstract class TransformWithStateInPandasPythonBaseRunner[I](
+abstract class TransformWithStateInPySparkPythonBaseRunner[I](
funcs: Seq[(ChainedPythonFunctions, Long)],
evalType: Int,
argOffsets: Array[Array[Int]],
@@ -178,7 +181,7 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
funcs.map(_._1), evalType, argOffsets, jobArtifactUUID, pythonMetrics)
with PythonArrowInput[I]
with BasicPythonArrowOutput
- with TransformWithStateInPandasPythonRunnerUtils
+ with TransformWithStateInPySparkPythonRunnerUtils
with Logging {
protected val sqlConf = SQLConf.get
@@ -196,8 +199,13 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
super.handleMetadataBeforeExec(stream)
- // Also write the port number for state server
- stream.writeInt(stateServerSocketPort)
+ // Also write the state server endpoint: the TCP port, or -1 followed by the socket path
+ if (isUnixDomainSock) {
+ stream.writeInt(-1)
+ PythonWorkerUtils.writeUTF(stateServerSocketPath, stream)
+ } else {
+ stream.writeInt(stateServerSocketPort)
+ }
PythonRDD.writeUTF(groupingKeySchema.json, stream)
}
@@ -211,9 +219,9 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
val executionContext = ExecutionContext.fromExecutor(executor)
executionContext.execute(
- new TransformWithStateInPandasStateServer(stateServerSocket, processorHandle,
+ new TransformWithStateInPySparkStateServer(stateServerSocket, processorHandle,
groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes,
- sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch,
+ sqlConf.arrowTransformWithStateInPySparkMaxStateRecordsPerBatch,
batchTimestampMs, eventTimeWatermarkForEviction))
context.addTaskCompletionListener[Unit] { _ =>
@@ -231,17 +239,17 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
}
/**
- * TransformWithStateInPandas driver side Python runner. Similar as executor side runner,
+ * TransformWithStateInPySpark driver side Python runner. Similar to the executor side runner, it
* will start a new daemon thread on the Python runner to run state server.
*/
-class TransformWithStateInPandasPythonPreInitRunner(
+class TransformWithStateInPySparkPythonPreInitRunner(
func: PythonFunction,
workerModule: String,
timeZoneId: String,
groupingKeySchema: StructType,
processorHandleImpl: DriverStatefulProcessorHandleImpl)
extends StreamingPythonRunner(func, "", "", workerModule)
- with TransformWithStateInPandasPythonRunnerUtils
+ with TransformWithStateInPySparkPythonRunnerUtils
with Logging {
protected val sqlConf = SQLConf.get
@@ -255,14 +263,19 @@ class TransformWithStateInPandasPythonPreInitRunner(
dataOut = result._1
dataIn = result._2
- // start state server, update socket port
+ // start the state server and record its socket port or path
startStateServer()
(dataOut, dataIn)
}
def process(): Unit = {
- // Also write the port number for state server
- dataOut.writeInt(stateServerSocketPort)
+ // Also write the state server endpoint: the TCP port, or -1 followed by the socket path
+ if (isUnixDomainSock) {
+ dataOut.writeInt(-1)
+ PythonWorkerUtils.writeUTF(stateServerSocketPath, dataOut)
+ } else {
+ dataOut.writeInt(stateServerSocketPort)
+ }
PythonWorkerUtils.writeUTF(groupingKeySchema.json, dataOut)
dataOut.flush()
@@ -285,13 +298,13 @@ class TransformWithStateInPandasPythonPreInitRunner(
daemonThread = new Thread {
override def run(): Unit = {
try {
- new TransformWithStateInPandasStateServer(stateServerSocket, processorHandleImpl,
+ new TransformWithStateInPySparkStateServer(stateServerSocket, processorHandleImpl,
groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames = true,
largeVarTypes = sqlConf.arrowUseLargeVarTypes,
- sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch).run()
+ sqlConf.arrowTransformWithStateInPySparkMaxStateRecordsPerBatch).run()
} catch {
case e: Exception =>
- throw new SparkException("TransformWithStateInPandas state server " +
+ throw new SparkException("TransformWithStateInPySpark state server " +
"daemon thread exited unexpectedly (crashed)", e)
}
}
@@ -303,18 +316,31 @@ class TransformWithStateInPandasPythonPreInitRunner(
}
/**
- * TransformWithStateInPandas Python runner utils functions for handling a state server
+ * TransformWithStateInPySpark Python runner utility functions for handling a state server
* in a new daemon thread.
*/
-trait TransformWithStateInPandasPythonRunnerUtils extends Logging {
- protected var stateServerSocketPort: Int = 0
- protected var stateServerSocket: ServerSocket = null
+trait TransformWithStateInPySparkPythonRunnerUtils extends Logging {
+ protected val isUnixDomainSock: Boolean = SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
+ protected var stateServerSocketPort: Int = -1
+ protected var stateServerSocketPath: String = null
+ protected var stateServerSocket: ServerSocketChannel = null
protected def initStateServer(): Unit = {
var failed = false
try {
- stateServerSocket = new ServerSocket(/* port = */ 0,
- /* backlog = */ 1)
- stateServerSocketPort = stateServerSocket.getLocalPort
+ if (isUnixDomainSock) {
+ val sockPath = new File(
+ SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+ .getOrElse(System.getProperty("java.io.tmpdir")),
+ s".${UUID.randomUUID()}.sock")
+ stateServerSocket = ServerSocketChannel.open(StandardProtocolFamily.UNIX)
+ stateServerSocket.bind(UnixDomainSocketAddress.of(sockPath.getPath), 1)
+ sockPath.deleteOnExit()
+ stateServerSocketPath = sockPath.getPath
+ } else {
+ stateServerSocket = ServerSocketChannel.open()
+ .bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
+ stateServerSocketPort = stateServerSocket.socket().getLocalPort
+ }
} catch {
case e: Throwable =>
failed = true
@@ -326,10 +352,13 @@ trait TransformWithStateInPandasPythonRunnerUtils extends Logging {
}
}
- protected def closeServerSocketChannelSilently(stateServerSocket: ServerSocket): Unit = {
+ protected def closeServerSocketChannelSilently(stateServerSocket: ServerSocketChannel): Unit = {
try {
logInfo(log"closing the state server socket")
stateServerSocket.close()
+ if (stateServerSocketPath != null) {
+ new File(stateServerSocketPath).delete
+ }
} catch {
case e: Exception =>
logError(log"failed to close state server socket", e)
@@ -337,7 +366,7 @@ trait TransformWithStateInPandasPythonRunnerUtils extends Logging {
}
}
-object TransformWithStateInPandasPythonRunner {
+object TransformWithStateInPySparkPythonRunner {
type InType = (InternalRow, Iterator[InternalRow])
type GroupedInType = (InternalRow, Iterator[InternalRow], Iterator[InternalRow])
}
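
With this change the runner binds the state server either to an ephemeral loopback TCP port or, when PYTHON_UNIX_DOMAIN_SOCKET_ENABLED is set, to a per-run Unix domain socket path, and passes whichever endpoint was chosen to the Python worker. A standalone sketch of that dual bind path (JDK 16+ is required for Unix domain sockets; the helper below is illustrative, not the Spark API):

import java.io.File
import java.net.{InetAddress, InetSocketAddress, StandardProtocolFamily, UnixDomainSocketAddress}
import java.nio.channels.ServerSocketChannel
import java.util.UUID

object StateServerBind {
  // Returns the bound channel plus the endpoint to hand to the Python worker:
  // Left(path) for a Unix domain socket, Right(port) for loopback TCP.
  def bind(useUnixDomainSocket: Boolean): (ServerSocketChannel, Either[String, Int]) = {
    if (useUnixDomainSocket) {
      val sockPath = new File(System.getProperty("java.io.tmpdir"), s".${UUID.randomUUID()}.sock")
      val channel = ServerSocketChannel.open(StandardProtocolFamily.UNIX)
      channel.bind(UnixDomainSocketAddress.of(sockPath.getPath), 1)
      sockPath.deleteOnExit()
      (channel, Left(sockPath.getPath))
    } else {
      val channel = ServerSocketChannel.open()
        .bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
      (channel, Right(channel.socket().getLocalPort))
    }
  }
}
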
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
similarity index 86%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
index f665db8b5b120..a2c4d130ef319 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
@@ -18,16 +18,19 @@
package org.apache.spark.sql.execution.python.streaming
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException}
-import java.net.ServerSocket
+import java.nio.channels.{Channels, ServerSocketChannel}
import java.time.Duration
import scala.collection.mutable
+import scala.jdk.CollectionConverters._
import com.google.protobuf.ByteString
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.apache.spark.SparkEnv
import org.apache.spark.internal.{Logging, LogKeys, MDC}
+import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.InternalRow
@@ -36,6 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleImplBase, StatefulProcessorHandleState, StateVariableType}
import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, ListStateCall, MapStateCall, StatefulProcessorCall, StateRequest, StateResponse, StateResponseWithLongTypeVal, StateResponseWithStringTypeVal, StateVariableRequest, TimerRequest, TimerStateCallCommand, TimerValueRequest, UtilsRequest, ValueStateCall}
+import org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponseWithListGet
import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, ValueState}
import org.apache.spark.sql.types.{BinaryType, LongType, StructField, StructType}
import org.apache.spark.sql.util.ArrowUtils
@@ -43,7 +47,7 @@ import org.apache.spark.util.Utils
/**
* This class is used to handle the state requests from the Python side. It runs on a separate
- * thread spawned by TransformWithStateInPandasStateRunner per task. It opens a dedicated socket
+ * thread spawned by TransformWithStateInPySparkStateRunner per task. It opens a dedicated socket
* to process/transfer state related info which is shut down when task finishes or there's an error
* on opening the socket. It processes following state requests and return responses to the
* Python side:
@@ -51,19 +55,19 @@ import org.apache.spark.util.Utils
* - Stateful processor requests.
* - Requests for managing state variables (e.g. valueState).
*/
-class TransformWithStateInPandasStateServer(
- stateServerSocket: ServerSocket,
+class TransformWithStateInPySparkStateServer(
+ stateServerSocket: ServerSocketChannel,
statefulProcessorHandle: StatefulProcessorHandleImplBase,
groupingKeySchema: StructType,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean,
- arrowTransformWithStateInPandasMaxRecordsPerBatch: Int,
+ arrowTransformWithStateInPySparkMaxRecordsPerBatch: Int,
batchTimestampMs: Option[Long] = None,
eventTimeWatermarkForEviction: Option[Long] = None,
outputStreamForTest: DataOutputStream = null,
valueStateMapForTest: mutable.HashMap[String, ValueStateInfo] = null,
- deserializerForTest: TransformWithStateInPandasDeserializer = null,
+ deserializerForTest: TransformWithStateInPySparkDeserializer = null,
arrowStreamWriterForTest: BaseStreamingArrowWriter = null,
listStatesMapForTest : mutable.HashMap[String, ListStateInfo] = null,
iteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null,
@@ -80,6 +84,10 @@ class TransformWithStateInPandasStateServer(
private var inputStream: DataInputStream = _
private var outputStream: DataOutputStream = outputStreamForTest
+ private val isUnixDomainSock = Option(SparkEnv.get)
+ .map(_.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED))
+ .getOrElse(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.defaultValue.get)
+
/** State variable related class variables */
// A map to store the value state name -> (value state, schema, value row deserializer) mapping.
private val valueStates = if (valueStateMapForTest != null) {
@@ -114,7 +122,7 @@ class TransformWithStateInPandasStateServer(
// A map to store the iterator id -> Iterator[(Row, Row)] mapping. This is to keep track of the
// current key-value iterator position for each iterator id in a map state for a grouping key in
// case user tries to fetch another state variable before the current iterator is exhausted.
- private var keyValueIterators = if (keyValueIteratorMapForTest != null) {
+ private val keyValueIterators = if (keyValueIteratorMapForTest != null) {
keyValueIteratorMapForTest
} else {
new mutable.HashMap[String, Iterator[(Row, Row)]]()
@@ -138,10 +146,22 @@ class TransformWithStateInPandasStateServer(
def run(): Unit = {
val listeningSocket = stateServerSocket.accept()
+
+ // SPARK-51667: We have a pattern of sending messages continuously from one side
+ // (Python -> JVM, and vice versa) before getting a response from the other side. Since most
+ // messages we send are small, this triggers the bad combination of Nagle's algorithm and
+ // delayed ACKs, which can add significant latency.
+ // See SPARK-51667 for more details on how this can be a problem.
+ //
+ // Disabling either would work, but disabling Nagle's algorithm is by far the more common
+ // practice; references on disabling delayed ACKs are comparatively scarce.
+ if (!isUnixDomainSock) listeningSocket.socket().setTcpNoDelay(true)
+
inputStream = new DataInputStream(
- new BufferedInputStream(listeningSocket.getInputStream))
+ new BufferedInputStream(Channels.newInputStream(listeningSocket)))
outputStream = new DataOutputStream(
- new BufferedOutputStream(listeningSocket.getOutputStream)
+ new BufferedOutputStream(Channels.newOutputStream(listeningSocket))
)
while (listeningSocket.isConnected &&
@@ -241,7 +261,7 @@ class TransformWithStateInPandasStateServer(
Option(statefulProcessorHandle
.asInstanceOf[StatefulProcessorHandleImpl].getExpiredTimers(expiryTimestamp))
}
- // expiryTimestampIter could be None in the TWSPandasServerSuite
+ // expiryTimestampIter could be None in the TWSPySparkServerSuite
if (!expiryTimestampIter.isDefined || !expiryTimestampIter.get.hasNext) {
// iterator is exhausted, signal the end of iterator on python client
sendResponse(1)
@@ -452,7 +472,7 @@ class TransformWithStateInPandasStateServer(
val deserializer = if (deserializerForTest != null) {
deserializerForTest
} else {
- new TransformWithStateInPandasDeserializer(listStateInfo.deserializer)
+ new TransformWithStateInPySparkDeserializer(listStateInfo.deserializer)
}
message.getMethodCase match {
case ListStateCall.MethodCase.EXISTS =>
@@ -463,7 +483,17 @@ class TransformWithStateInPandasStateServer(
sendResponse(2, s"state $stateName doesn't exist")
}
case ListStateCall.MethodCase.LISTSTATEPUT =>
- val rows = deserializer.readArrowBatches(inputStream)
+ val rows = if (message.getListStatePut.getFetchWithArrow) {
+ deserializer.readArrowBatches(inputStream)
+ } else {
+ val elements = message.getListStatePut.getValueList.asScala
+ elements.map { e =>
+ PythonSQLUtils.toJVMRow(
+ e.toByteArray,
+ listStateInfo.schema,
+ listStateInfo.deserializer)
+ }
+ }
listStateInfo.listState.put(rows.toArray)
sendResponse(0)
case ListStateCall.MethodCase.LISTSTATEGET =>
@@ -475,12 +505,9 @@ class TransformWithStateInPandasStateServer(
}
if (!iteratorOption.get.hasNext) {
sendResponse(2, s"List state $stateName doesn't contain any value.")
- return
} else {
- sendResponse(0)
+ sendResponseWithListGet(0, iter = iteratorOption.get)
}
- sendIteratorAsArrowBatches(iteratorOption.get, listStateInfo.schema,
- arrowStreamWriterForTest) { data => listStateInfo.serializer(data)}
case ListStateCall.MethodCase.APPENDVALUE =>
val byteArray = message.getAppendValue.getValue.toByteArray
val newRow = PythonSQLUtils.toJVMRow(byteArray, listStateInfo.schema,
@@ -488,7 +515,17 @@ class TransformWithStateInPandasStateServer(
listStateInfo.listState.appendValue(newRow)
sendResponse(0)
case ListStateCall.MethodCase.APPENDLIST =>
- val rows = deserializer.readArrowBatches(inputStream)
+ val rows = if (message.getAppendList.getFetchWithArrow) {
+ deserializer.readArrowBatches(inputStream)
+ } else {
+ val elements = message.getAppendList.getValueList.asScala
+ elements.map { e =>
+ PythonSQLUtils.toJVMRow(
+ e.toByteArray,
+ listStateInfo.schema,
+ listStateInfo.deserializer)
+ }
+ }
listStateInfo.listState.appendList(rows.toArray)
sendResponse(0)
case ListStateCall.MethodCase.CLEAR =>
@@ -499,6 +536,28 @@ class TransformWithStateInPandasStateServer(
}
}
+ private def sendIteratorForListState(iter: Iterator[Row]): Unit = {
+ // Only write a single batch in each GET request. Stop writing rows once rowCount reaches
+ // the arrowTransformWithStateInPySparkMaxRecordsPerBatch limit. This handles the case where
+ // there are multiple state variables and the user accesses a different state variable before
+ // the current one is exhausted.
+ var rowCount = 0
+ while (iter.hasNext && rowCount < arrowTransformWithStateInPySparkMaxRecordsPerBatch) {
+ val data = iter.next()
+
+ // Serialize the value row as a byte array
+ val valueBytes = PythonSQLUtils.toPyRow(data)
+ val lenBytes = valueBytes.length
+
+ outputStream.writeInt(lenBytes)
+ outputStream.write(valueBytes)
+
+ rowCount += 1
+ }
+ outputStream.writeInt(-1)
+ outputStream.flush()
+ }
+
private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
val stateName = message.getStateName
if (!mapStates.contains(stateName)) {
@@ -612,6 +671,9 @@ class TransformWithStateInPandasStateServer(
mapStateInfo.keyDeserializer)
mapStateInfo.mapState.removeKey(keyRow)
sendResponse(0)
+ case MapStateCall.MethodCase.CLEAR =>
+ mapStateInfo.mapState.clear()
+ sendResponse(0)
case _ =>
throw new IllegalArgumentException("Invalid method call")
}
@@ -730,6 +792,46 @@ class TransformWithStateInPandasStateServer(
outputStream.write(responseMessageBytes)
}
+ def sendResponseWithListGet(
+ status: Int,
+ errorMessage: String = null,
+ iter: Iterator[Row] = null): Unit = {
+ val responseMessageBuilder = StateResponseWithListGet.newBuilder()
+ .setStatusCode(status)
+ if (status != 0 && errorMessage != null) {
+ responseMessageBuilder.setErrorMessage(errorMessage)
+ }
+
+ if (status == 0) {
+ // Only write a single batch in each GET request. Stop writing rows once rowCount reaches
+ // the arrowTransformWithStateInPySparkMaxRecordsPerBatch limit. This handles the case where
+ // there are multiple state variables and the user accesses a different state variable before
+ // the current one is exhausted.
+ var rowCount = 0
+ while (iter.hasNext && rowCount < arrowTransformWithStateInPySparkMaxRecordsPerBatch) {
+ val data = iter.next()
+
+ // Serialize the value row as a byte array
+ val valueBytes = PythonSQLUtils.toPyRow(data)
+
+ responseMessageBuilder.addValue(ByteString.copyFrom(valueBytes))
+
+ rowCount += 1
+ }
+
+ assert(rowCount > 0, s"rowCount should be greater than 0 when status code is 0, " +
+ s"iter.hasNext ${iter.hasNext}")
+
+ responseMessageBuilder.setRequireNextFetch(iter.hasNext)
+ }
+
+ val responseMessage = responseMessageBuilder.build()
+ val responseMessageBytes = responseMessage.toByteArray
+ val byteLength = responseMessageBytes.length
+ outputStream.writeInt(byteLength)
+ outputStream.write(responseMessageBytes)
+ }
+
def sendIteratorAsArrowBatches[T](
iter: Iterator[T],
outputSchema: StructType,
@@ -738,21 +840,21 @@ class TransformWithStateInPandasStateServer(
val arrowSchema = ArrowUtils.toArrowSchema(outputSchema, timeZoneId,
errorOnDuplicatedFieldNames, largeVarTypes)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
- s"stdout writer for transformWithStateInPandas state socket", 0, Long.MaxValue)
+ s"stdout writer for transformWithStateInPySpark state socket", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val writer = new ArrowStreamWriter(root, null, outputStream)
val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
arrowStreamWriterForTest
} else {
new BaseStreamingArrowWriter(root, writer,
- arrowTransformWithStateInPandasMaxRecordsPerBatch)
+ arrowTransformWithStateInPySparkMaxRecordsPerBatch)
}
// Only write a single batch in each GET request. Stops writing row if rowCount reaches
- // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to handle a case
+ // the arrowTransformWithStateInPySparkMaxRecordsPerBatch limit. This is to handle a case
// when there are multiple state variables, user tries to access a different state variable
// while the current state variable is not exhausted yet.
var rowCount = 0
- while (iter.hasNext && rowCount < arrowTransformWithStateInPandasMaxRecordsPerBatch) {
+ while (iter.hasNext && rowCount < arrowTransformWithStateInPySparkMaxRecordsPerBatch) {
val data = iter.next()
val internalRow = func(data)
arrowStreamWriter.writeRow(internalRow)
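
Both sendResponseWithListGet and sendIteratorAsArrowBatches above cap the number of rows returned per GET at the configured max-records-per-batch and let the Python side fetch again for the remainder, so a half-consumed list iterator can be resumed even if the user touches another state variable in between. A toy sketch of that capped-drain pattern (illustrative helper, not Spark API):

object CappedDrain {
  // Drain at most maxRecordsPerBatch elements; the Boolean plays the role of requireNextFetch.
  def takeBatch[T](iter: Iterator[T], maxRecordsPerBatch: Int): (Seq[T], Boolean) = {
    val batch = scala.collection.mutable.ArrayBuffer.empty[T]
    while (iter.hasNext && batch.size < maxRecordsPerBatch) {
      batch += iter.next()
    }
    (batch.toSeq, iter.hasNext)
  }

  def main(args: Array[String]): Unit = {
    val it = Iterator(1, 2, 3)
    val (first, more) = takeBatch(it, maxRecordsPerBatch = 2)
    assert(first == Seq(1, 2) && more)   // another fetch is required
    val (rest, moreAfter) = takeBatch(it, maxRecordsPerBatch = 2)
    assert(rest == Seq(3) && !moreAfter) // iterator exhausted
  }
}
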
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 89fc69cd2bdd9..e812af229524f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -20,9 +20,8 @@ package org.apache.spark.sql.execution.stat
import java.util.Locale
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Column, Row}
+import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.functions._
@@ -68,44 +67,35 @@ object StatFunctions extends Logging {
relativeError: Double): Seq[Seq[Double]] = {
require(relativeError >= 0,
s"Relative Error must be non-negative but got $relativeError")
+ require(probabilities.forall(p => p >= 0 && p <= 1.0),
+ "percentile should be in the range [0.0, 1.0]")
val columns: Seq[Column] = cols.map { colName =>
val field = df.resolve(colName)
require(field.dataType.isInstanceOf[NumericType],
s"Quantile calculation for column $colName with data type ${field.dataType}" +
" is not supported.")
- Column(colName).cast(DoubleType)
+ col(colName).cast(DoubleType)
}
- val emptySummaries = Array.fill(cols.size)(
- new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError))
- // Note that it works more or less by accident as `rdd.aggregate` is not a pure function:
- // this function returns the same array as given in the input (because `aggregate` reuses
- // the same argument).
- def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = {
- var i = 0
- while (i < summaries.length) {
- if (!row.isNullAt(i)) {
- val v = row.getDouble(i)
- if (!v.isNaN) summaries(i) = summaries(i).insert(v)
- }
- i += 1
- }
- summaries
- }
-
- def merge(
- sum1: Array[QuantileSummaries],
- sum2: Array[QuantileSummaries]): Array[QuantileSummaries] = {
- sum1.zip(sum2).map { case (s1, s2) => s1.compress().merge(s2.compress()) }
+ // approx_percentile needs integer accuracy
+ val accuracy = if (relativeError == 0.0) {
+ Int.MaxValue
+ } else {
+ math.min(Int.MaxValue, (1.0 / relativeError).ceil.toLong).toInt
}
- val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge)
- summaries.map {
- summary => summary.query(probabilities) match {
- case Some(q) => q
- case None => Seq()
+ val results = Array.fill(cols.size)(Seq.empty[Double])
+ df.select(posexplode(array(columns: _*)).as(Seq("index", "value")))
+ .where(!isnull(col("value")) && !isnan(col("value")))
+ .groupBy("index")
+ .agg(approx_percentile(col("value"), lit(probabilities), lit(accuracy)))
+ .collect()
+ .foreach { row =>
+ val index = row.getInt(0)
+ val quantiles = row.getSeq[Double](1)
+ results(index) = quantiles
}
- }.toImmutableArraySeq
+ results.toImmutableArraySeq
}
/** Calculate the Pearson Correlation Coefficient for the given columns */
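
multipleApproxQuantiles now delegates to approx_percentile over a posexplode of the requested columns instead of tree-aggregating QuantileSummaries on the driver. The only non-obvious step is translating the relative error the API accepts into the integer accuracy approx_percentile expects (approx_percentile documents its relative error as roughly 1.0/accuracy). A small sketch mirroring the conversion above, for illustration only:

object QuantileAccuracy {
  // Larger accuracy means a more precise approximation, so a relative error e maps to
  // ceil(1/e), capped at Int.MaxValue; e == 0 requests an exact result.
  def toAccuracy(relativeError: Double): Int = {
    require(relativeError >= 0, s"Relative Error must be non-negative but got $relativeError")
    if (relativeError == 0.0) {
      Int.MaxValue
    } else {
      math.min(Int.MaxValue, (1.0 / relativeError).ceil.toLong).toInt
    }
  }

  def main(args: Array[String]): Unit = {
    assert(toAccuracy(0.01) == 100)
    assert(toAccuracy(0.0) == Int.MaxValue)
    assert(toAccuracy(1e-10) == Int.MaxValue) // capped rather than overflowing
  }
}
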
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index f0debce44e376..465973cabe587 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -315,7 +315,7 @@ class FileStreamSource(
className = fileFormatClassName,
options = optionsForInnerDataSource)
Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation(
- checkFilesExist = false), isStreaming = true))
+ checkFilesExist = false, readOnly = true), isStreaming = true))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 246057a5a9d0a..b6701182d7e06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.{CommandExecutionMode, LocalLimitExec, Que
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, UpdatingSessionsExec}
import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
-import org.apache.spark.sql.execution.python.streaming.{FlatMapGroupsInPandasWithStateExec, TransformWithStateInPandasExec}
+import org.apache.spark.sql.execution.python.streaming.{FlatMapGroupsInPandasWithStateExec, TransformWithStateInPySparkExec}
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter, StateSchemaBroadcast, StateSchemaMetadata}
import org.apache.spark.sql.internal.SQLConf
@@ -71,7 +71,8 @@ class IncrementalExecution(
MutableMap[Long, Array[Array[String]]] = MutableMap[Long, Array[Array[String]]](),
val stateSchemaMetadatas: MutableMap[Long, StateSchemaBroadcast] =
MutableMap[Long, StateSchemaBroadcast](),
- mode: CommandExecutionMode.Value = CommandExecutionMode.ALL)
+ mode: CommandExecutionMode.Value = CommandExecutionMode.ALL,
+ val isTerminatingTrigger: Boolean = false)
extends QueryExecution(sparkSession, logicalPlan, mode = mode) with Logging {
// Modified planner with stateful operations.
@@ -91,7 +92,7 @@ class IncrementalExecution(
StreamingDeduplicationStrategy ::
StreamingGlobalLimitStrategy(outputMode) ::
StreamingTransformWithStateStrategy ::
- TransformWithStateInPandasStrategy :: Nil
+ TransformWithStateInPySparkStrategy :: Nil
}
private lazy val hadoopConf = sparkSession.sessionState.newHadoopConf()
@@ -222,9 +223,12 @@ class IncrementalExecution(
// filepath, and write this path out in the OperatorStateMetadata file
case statefulOp: StatefulOperator if isFirstBatch =>
val stateSchemaVersion = statefulOp match {
- case _: TransformWithStateExec | _: TransformWithStateInPandasExec =>
+ case _: TransformWithStateExec | _: TransformWithStateInPySparkExec =>
sparkSession.sessionState.conf.
getConf(SQLConf.STREAMING_TRANSFORM_WITH_STATE_OP_STATE_SCHEMA_VERSION)
+ case _: StreamingSymmetricHashJoinExec =>
+ sparkSession.sessionState.conf.
+ getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION)
case _ => STATE_SCHEMA_DEFAULT_VERSION
}
val schemaValidationResult = statefulOp.
@@ -278,7 +282,7 @@ class IncrementalExecution(
case exec: TransformWithStateExec =>
exec.copy(stateInfo = Some(exec.getStateInfo.copy(
stateSchemaMetadata = Some(stateSchemaBroadcast))))
- case exec: TransformWithStateInPandasExec =>
+ case exec: TransformWithStateInPySparkExec =>
exec.copy(stateInfo = Some(exec.getStateInfo.copy(
stateSchemaMetadata = Some(stateSchemaBroadcast))))
// Add other cases if needed for different StateStoreWriter implementations
@@ -372,7 +376,7 @@ class IncrementalExecution(
hasInitialState = hasInitialState
)
- case t: TransformWithStateInPandasExec =>
+ case t: TransformWithStateInPySparkExec =>
val hasInitialState = (currentBatchId == 0L && t.hasInitialState)
t.copy(
stateInfo = Some(nextStatefulOperationStateInfo()),
@@ -474,11 +478,11 @@ class IncrementalExecution(
// UpdateEventTimeColumnExec is used to tag the eventTime column, and validate
// emitted rows adhere to watermark in the output of transformWithStateInp.
- // Hence, this node shares the same watermark value as TransformWithStateInPandasExec.
+ // Hence, this node shares the same watermark value as TransformWithStateInPySparkExec.
// This is the same as above in TransformWithStateExec.
- // The only difference is TransformWithStateInPandasExec is analysed slightly different
+ // The only difference is TransformWithStateInPySparkExec is analysed slightly different
// with no SerializeFromObjectExec wrapper.
- case UpdateEventTimeColumnExec(eventTime, delay, None, t: TransformWithStateInPandasExec)
+ case UpdateEventTimeColumnExec(eventTime, delay, None, t: TransformWithStateInPySparkExec)
if t.stateInfo.isDefined =>
val stateInfo = t.stateInfo.get
val iwLateEvents = inputWatermarkForLateEvents(stateInfo)
@@ -496,7 +500,7 @@ class IncrementalExecution(
eventTimeWatermarkForEviction = inputWatermarkForEviction(t.stateInfo.get)
)
- case t: TransformWithStateInPandasExec if t.stateInfo.isDefined =>
+ case t: TransformWithStateInPySparkExec if t.stateInfo.isDefined =>
t.copy(
eventTimeWatermarkForLateEvents = inputWatermarkForLateEvents(t.stateInfo.get),
eventTimeWatermarkForEviction = inputWatermarkForEviction(t.stateInfo.get)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index fe06cbb19c3a1..1dd70ad985cc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -501,6 +501,9 @@ class MicroBatchExecution(
* i.e., committedBatchId + 1 */
commitLog.getLatest() match {
case Some((latestCommittedBatchId, commitMetadata)) =>
+ commitMetadata.stateUniqueIds.foreach {
+ stateUniqueIds => currentStateStoreCkptId ++= stateUniqueIds
+ }
if (latestBatchId == latestCommittedBatchId) {
/* The last batch was successfully committed, so we can safely process a
* new next batch but first:
@@ -520,9 +523,6 @@ class MicroBatchExecution(
execCtx.startOffsets ++= execCtx.endOffsets
watermarkTracker.setWatermark(
math.max(watermarkTracker.currentWatermark, commitMetadata.nextBatchWatermarkMs))
- commitMetadata.stateUniqueIds.foreach {
- stateUniqueIds => currentStateStoreCkptId ++= stateUniqueIds
- }
} else if (latestCommittedBatchId == latestBatchId - 1) {
execCtx.endOffsets.foreach {
case (source: Source, end: Offset) =>
@@ -858,7 +858,8 @@ class MicroBatchExecution(
watermarkPropagator,
execCtx.previousContext.isEmpty,
currentStateStoreCkptId,
- stateSchemaMetadatas)
+ stateSchemaMetadatas,
+ isTerminatingTrigger = trigger.isInstanceOf[AvailableNowTrigger.type])
execCtx.executionPlan.executedPlan // Force the lazy generation of execution plan
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 3ac07cf1d7308..dc04ba3331e71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream}
import org.apache.spark.sql.execution.{QueryExecution, StreamSourceAwareSparkPlan}
import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress}
+import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent}
import org.apache.spark.util.{Clock, Utils}
@@ -61,6 +62,12 @@ class ProgressReporter(
val noDataProgressEventInterval: Long =
sparkSession.sessionState.conf.streamingNoDataProgressEventInterval
+ val coordinatorReportSnapshotUploadLag: Boolean =
+ sparkSession.sessionState.conf.stateStoreCoordinatorReportSnapshotUploadLag
+
+ val stateStoreCoordinator: StateStoreCoordinatorRef =
+ sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator
+
private val timestampFormat =
DateTimeFormatter
.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601
@@ -283,6 +290,17 @@ abstract class ProgressContext(
progressReporter.lastNoExecutionProgressEventTime = triggerClock.getTimeMillis()
progressReporter.updateProgress(newProgress)
+ // Ask the state store coordinator to log all lagging state stores
+ if (progressReporter.coordinatorReportSnapshotUploadLag) {
+ val latestVersion = lastEpochId + 1
+ progressReporter.stateStoreCoordinator
+ .logLaggingStateStores(
+ lastExecution.runId,
+ latestVersion,
+ lastExecution.isTerminatingTrigger
+ )
+ }
+
// Update the value since this trigger executes a batch successfully.
this.execStatsOnLatestExecutedBatch = Some(execStats)
@@ -559,7 +577,7 @@ abstract class ProgressContext(
hasNewData: Boolean,
sourceToNumInputRows: Map[SparkDataStream, Long],
lastExecution: IncrementalExecution): ExecutionStats = {
- val hasEventTime = progressReporter.logicalPlan().collect {
+ val hasEventTime = progressReporter.logicalPlan().collectFirst {
case e: EventTimeWatermark => e
}.nonEmpty
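
The switch from collect to collectFirst above is a small win: the check only needs to know whether any EventTimeWatermark node exists, and collectFirst stops at the first match instead of materializing every hit (the assumption here is that the plan's collectFirst short-circuits the same way plain collections do). A quick illustration on an iterator:

object CollectFirstDemo {
  def main(args: Array[String]): Unit = {
    var visited = 0
    val nodes = (1 to 1000).iterator.map { n => visited += 1; n }

    // collectFirst stops consuming as soon as the partial function matches once.
    val found = nodes.collectFirst { case n if n % 10 == 0 => n }
    assert(found.contains(10))
    assert(visited == 10) // the remaining 990 elements were never touched
  }
}
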
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index f1f0ddf206c60..3cf3286fafb80 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -458,7 +458,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
stateName,
keyExprEnc.schema
)
- null.asInstanceOf[ValueState[T]]
+ new InvalidHandleValueState[T](stateName)
}
override def getListState[T](
@@ -492,7 +492,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
stateName,
keyExprEnc.schema
)
- null.asInstanceOf[ListState[T]]
+ new InvalidHandleListState[T](stateName)
}
override def getMapState[K, V](
@@ -522,7 +522,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val stateVariableInfo = TransformWithStateVariableUtils.
getMapState(stateName, ttlEnabled = ttlEnabled)
stateVariableInfos.put(stateName, stateVariableInfo)
- null.asInstanceOf[MapState[K, V]]
+ new InvalidHandleMapState[K, V](stateName)
}
/**
@@ -655,3 +655,43 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
verifyStateVarOperations("delete_if_exists", PRE_INIT)
}
}
+
+private[sql] trait InvalidHandleState {
+ protected val stateName: String
+
+ protected def throwInitPhaseError(operation: String): Nothing = {
+ throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState(
+ s"$stateName.$operation", PRE_INIT.toString)
+ }
+}
+
+private[sql] class InvalidHandleValueState[S](override val stateName: String)
+ extends ValueState[S] with InvalidHandleState {
+ override def exists(): Boolean = throwInitPhaseError("exists")
+ override def get(): S = throwInitPhaseError("get")
+ override def update(newState: S): Unit = throwInitPhaseError("update")
+ override def clear(): Unit = throwInitPhaseError("clear")
+}
+
+private[sql] class InvalidHandleListState[S](override val stateName: String)
+ extends ListState[S] with InvalidHandleState {
+ override def exists(): Boolean = throwInitPhaseError("exists")
+ override def get(): Iterator[S] = throwInitPhaseError("get")
+ override def put(newState: Array[S]): Unit = throwInitPhaseError("put")
+ override def appendValue(newState: S): Unit = throwInitPhaseError("appendValue")
+ override def appendList(newState: Array[S]): Unit = throwInitPhaseError("appendList")
+ override def clear(): Unit = throwInitPhaseError("clear")
+}
+
+private[sql] class InvalidHandleMapState[K, V](override val stateName: String)
+ extends MapState[K, V] with InvalidHandleState {
+ override def exists(): Boolean = throwInitPhaseError("exists")
+ override def getValue(key: K): V = throwInitPhaseError("getValue")
+ override def containsKey(key: K): Boolean = throwInitPhaseError("containsKey")
+ override def updateValue(key: K, value: V): Unit = throwInitPhaseError("updateValue")
+ override def iterator(): Iterator[(K, V)] = throwInitPhaseError("iterator")
+ override def keys(): Iterator[K] = throwInitPhaseError("keys")
+ override def values(): Iterator[V] = throwInitPhaseError("values")
+ override def removeKey(key: K): Unit = throwInitPhaseError("removeKey")
+ override def clear(): Unit = throwInitPhaseError("clear")
+}
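
The InvalidHandle* classes above replace the old behavior of returning null from the driver-side pre-init handle, so calling into a state variable before initialization now fails immediately with a descriptive error rather than surfacing later as a NullPointerException. A generic sketch of the same fail-fast placeholder pattern, with illustrative types that are not part of the Spark API:

object FailFastPlaceholder {
  trait Counter {
    def increment(): Unit
    def value: Long
  }

  // Placeholder returned before initialization: every call fails loudly with context,
  // instead of handing back null and failing later, far away from the misuse.
  final class InvalidCounter(name: String) extends Counter {
    private def fail(op: String): Nothing =
      throw new IllegalStateException(s"Cannot call $name.$op before initialization")
    override def increment(): Unit = fail("increment")
    override def value: Long = fail("value")
  }

  def main(args: Array[String]): Unit = {
    val c: Counter = new InvalidCounter("counter")
    try { c.increment() } catch {
      case e: IllegalStateException => assert(e.getMessage.contains("counter.increment"))
    }
  }
}
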
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 7c8ba260b88af..ccb8b69f7c831 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -140,7 +140,7 @@ case class StreamingSymmetricHashJoinExec(
stateWatermarkPredicates: JoinStateWatermarkPredicates,
stateFormatVersion: Int,
left: SparkPlan,
- right: SparkPlan) extends BinaryExecNode with StateStoreWriter {
+ right: SparkPlan) extends BinaryExecNode with StateStoreWriter with SchemaValidationUtils {
def this(
leftKeys: Seq[Expression],
@@ -196,6 +196,19 @@ case class StreamingSymmetricHashJoinExec(
private val allowMultipleStatefulOperators =
conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)
+ private val useVirtualColumnFamilies = stateFormatVersion == 3
+
+ // Determine the store names and metadata version based on format version
+ private val (numStoresPerPartition, _stateStoreNames, _operatorStateMetadataVersion) =
+ if (useVirtualColumnFamilies) {
+ // We have 1 state store using virtual column families
+ (1, Seq(StateStoreId.DEFAULT_STORE_NAME), 2)
+ } else {
+ // We have 4 state stores (2 on either side of the join)
+ val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+ (stateStoreNames.size, stateStoreNames, 1)
+ }
+
val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length)
val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
@@ -224,16 +237,37 @@ case class StreamingSymmetricHashJoinExec(
override def shortName: String = "symmetricHashJoin"
- override val stateStoreNames: Seq[String] =
- SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+ override val stateStoreNames: Seq[String] = _stateStoreNames
+
+ override def operatorStateMetadataVersion: Int = _operatorStateMetadataVersion
override def operatorStateMetadata(
stateSchemaPaths: List[List[String]] = List.empty): OperatorStateMetadata = {
val info = getStateInfo
val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
- val stateStoreInfo =
- stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray
- OperatorStateMetadataV1(operatorInfo, stateStoreInfo)
+ if (useVirtualColumnFamilies) {
+ // Use MetadataV2 for join operators that use virtual column families
+ val stateStoreInfo = stateStoreNames.zip(stateSchemaPaths).map {
+ case (storeName, schemaPath) =>
+ StateStoreMetadataV2(storeName, 0, info.numPartitions, schemaPath)
+ }.toArray
+ val properties = StreamingJoinOperatorProperties(useVirtualColumnFamilies)
+ OperatorStateMetadataV2(operatorInfo, stateStoreInfo, properties.json)
+ } else {
+ val stateStoreInfo =
+ stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray
+ OperatorStateMetadataV1(operatorInfo, stateStoreInfo)
+ }
+ }
+
+ override def getColFamilySchemas(
+ shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
+ assert(useVirtualColumnFamilies)
+ // We only have one state store for the join, but there are four distinct schemas
+ SymmetricHashJoinStateManager
+ .getSchemasForStateStoreWithColFamily(LeftSide, left.output, leftKeys, stateFormatVersion) ++
+ SymmetricHashJoinStateManager
+ .getSchemasForStateStoreWithColFamily(RightSide, right.output, rightKeys, stateFormatVersion)
}
override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
@@ -252,29 +286,37 @@ case class StreamingSymmetricHashJoinExec(
hadoopConf: Configuration,
batchId: Long,
stateSchemaVersion: Int): List[StateSchemaValidationResult] = {
- var result: Map[String, (StructType, StructType)] = Map.empty
- // get state schema for state stores on left side of the join
- result ++= SymmetricHashJoinStateManager.getSchemaForStateStores(LeftSide,
- left.output, leftKeys, stateFormatVersion)
-
- // get state schema for state stores on right side of the join
- result ++= SymmetricHashJoinStateManager.getSchemaForStateStores(RightSide,
- right.output, rightKeys, stateFormatVersion)
-
- // validate and maybe evolve schema for all state stores across both sides of the join
- result.map { case (stateStoreName, (keySchema, valueSchema)) =>
- // we have to add the default column family schema because the RocksDBStateEncoder
- // expects this entry to be present in the stateSchemaProvider.
- val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
- keySchema, 0, valueSchema))
- StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
- newStateSchema, session.sessionState, stateSchemaVersion, storeName = stateStoreName)
- }.toList
+ if (useVirtualColumnFamilies) {
+ val info = getStateInfo
+ val stateSchemaDir = stateSchemaDirPath()
+
+ validateAndWriteStateSchema(
+ hadoopConf, batchId, stateSchemaVersion, info, stateSchemaDir, session
+ )
+ } else {
+ var result: Map[String, (StructType, StructType)] = Map.empty
+ // get state schema for state stores on left side of the join
+ result ++= SymmetricHashJoinStateManager.getSchemaForStateStores(LeftSide,
+ left.output, leftKeys, stateFormatVersion)
+
+ // get state schema for state stores on right side of the join
+ result ++= SymmetricHashJoinStateManager.getSchemaForStateStores(RightSide,
+ right.output, rightKeys, stateFormatVersion)
+
+ // validate and maybe evolve schema for all state stores across both sides of the join
+ result.map { case (stateStoreName, (keySchema, valueSchema)) =>
+ // we have to add the default column family schema because the RocksDBStateEncoder
+ // expects this entry to be present in the stateSchemaProvider.
+ val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
+ keySchema, 0, valueSchema))
+ StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
+ newStateSchema, session.sessionState, stateSchemaVersion, storeName = stateStoreName)
+ }.toList
+ }
}
protected override def doExecute(): RDD[InternalRow] = {
val stateStoreCoord = session.sessionState.streamingQueryManager.stateStoreCoordinator
- val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
metrics // initialize metrics
left.execute().stateStoreAwareZipPartitions(
right.execute(), stateInfo.get, stateStoreNames, stateStoreCoord)(processPartitions)
@@ -307,19 +349,25 @@ case class StreamingSymmetricHashJoinExec(
assert(stateInfo.isDefined, "State info not defined")
val checkpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
- partitionId, stateInfo.get)
+ partitionId, stateInfo.get, useVirtualColumnFamilies)
val inputSchema = left.output ++ right.output
val postJoinFilter =
Predicate.create(condition.bothSides.getOrElse(Literal(true)), inputSchema).eval _
- val leftSideJoiner = new OneSideHashJoiner(
- LeftSide, left.output, leftKeys, leftInputIter,
- condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left, partitionId,
- checkpointIds.left.keyToNumValues, checkpointIds.left.valueToNumKeys, skippedNullValueCount)
- val rightSideJoiner = new OneSideHashJoiner(
- RightSide, right.output, rightKeys, rightInputIter,
- condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right, partitionId,
- checkpointIds.right.keyToNumValues, checkpointIds.right.valueToNumKeys, skippedNullValueCount)
+ // Create left and right side hash joiners and store in the joiner manager.
+ // Both sides should use the same store generator if we are re-using the same store instance.
+ val joinStateManagerStoreGenerator = new JoinStateManagerStoreGenerator()
+ val joinerManager = OneSideHashJoinerManager(
+ new OneSideHashJoiner(
+ LeftSide, left.output, leftKeys, leftInputIter,
+ condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left, partitionId,
+ checkpointIds.left.keyToNumValues, checkpointIds.left.valueToNumKeys,
+ skippedNullValueCount, joinStateManagerStoreGenerator),
+ new OneSideHashJoiner(
+ RightSide, right.output, rightKeys, rightInputIter,
+ condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right, partitionId,
+ checkpointIds.right.keyToNumValues, checkpointIds.right.valueToNumKeys,
+ skippedNullValueCount, joinStateManagerStoreGenerator))
// Join one side input using the other side's buffered/state rows. Here is how it is done.
//
@@ -338,12 +386,14 @@ case class StreamingSymmetricHashJoinExec(
// - Left Semi Join: generates all stored left input rows, from matching new right input
// with stored left input, and also stores all the right input. Note only first-time
// matched left input rows will be generated, this is to guarantee left semi semantics.
- val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) {
- (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched)
- }
- val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) {
- (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(matched).withRight(input)
- }
+ val leftOutputIter =
+ joinerManager.leftSideJoiner.storeAndJoinWithOtherSide(joinerManager.rightSideJoiner) {
+ (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched)
+ }
+ val rightOutputIter =
+ joinerManager.rightSideJoiner.storeAndJoinWithOtherSide(joinerManager.leftSideJoiner) {
+ (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(matched).withRight(input)
+ }
// We need to save the time that the one side hash join output iterator completes, since
// other join output counts as both update and removal time.
@@ -375,17 +425,17 @@ case class StreamingSymmetricHashJoinExec(
// flag along with row, which is set to true when there's any matching row on the right.
def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = {
- rightSideJoiner.get(leftKeyValue.key).exists { rightValue =>
+ joinerManager.rightSideJoiner.get(leftKeyValue.key).exists { rightValue =>
postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue))
}
}
val initIterFn = { () =>
- val removedRowIter = leftSideJoiner.removeOldState()
+ val removedRowIter = joinerManager.leftSideJoiner.removeOldState()
removedRowIter.filterNot { kv =>
stateFormatVersion match {
case 1 => matchesWithRightSideState(new UnsafeRowPair(kv.key, kv.value))
- case 2 => kv.matched
+ case 2 | 3 => kv.matched
case _ => throwBadStateFormatVersionException()
}
}.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
@@ -401,17 +451,17 @@ case class StreamingSymmetricHashJoinExec(
case RightOuter =>
// See comments for left outer case.
def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = {
- leftSideJoiner.get(rightKeyValue.key).exists { leftValue =>
+ joinerManager.leftSideJoiner.get(rightKeyValue.key).exists { leftValue =>
postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
}
}
val initIterFn = { () =>
- val removedRowIter = rightSideJoiner.removeOldState()
+ val removedRowIter = joinerManager.rightSideJoiner.removeOldState()
removedRowIter.filterNot { kv =>
stateFormatVersion match {
case 1 => matchesWithLeftSideState(new UnsafeRowPair(kv.key, kv.value))
- case 2 => kv.matched
+ case 2 | 3 => kv.matched
case _ => throwBadStateFormatVersionException()
}
}.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
@@ -427,18 +477,18 @@ case class StreamingSymmetricHashJoinExec(
case FullOuter =>
lazy val isKeyToValuePairMatched = (kv: KeyToValuePair) =>
stateFormatVersion match {
- case 2 => kv.matched
+ case 2 | 3 => kv.matched
case _ => throwBadStateFormatVersionException()
}
val leftSideInitIterFn = { () =>
- val removedRowIter = leftSideJoiner.removeOldState()
+ val removedRowIter = joinerManager.leftSideJoiner.removeOldState()
removedRowIter.filterNot(isKeyToValuePairMatched)
.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
}
val rightSideInitIterFn = { () =>
- val removedRowIter = rightSideJoiner.removeOldState()
+ val removedRowIter = joinerManager.rightSideJoiner.removeOldState()
removedRowIter.filterNot(isKeyToValuePairMatched)
.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
}
@@ -493,10 +543,9 @@ case class StreamingSymmetricHashJoinExec(
// For full outer joins, we have already removed unnecessary states from both sides, so
// nothing needs to be outputted here.
val cleanupIter = joinType match {
- case Inner | LeftSemi =>
- leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
- case LeftOuter => rightSideJoiner.removeOldState()
- case RightOuter => leftSideJoiner.removeOldState()
+ case Inner | LeftSemi => joinerManager.removeOldState()
+ case LeftOuter => joinerManager.rightSideJoiner.removeOldState()
+ case RightOuter => joinerManager.leftSideJoiner.removeOldState()
case FullOuter => Iterator.empty
case _ => throwBadJoinTypeException()
}
@@ -508,31 +557,22 @@ case class StreamingSymmetricHashJoinExec(
// Commit all state changes and update state store metrics
commitTimeMs += timeTakenMs {
- val leftSideMetrics = leftSideJoiner.commitStateAndGetMetrics()
- val rightSideMetrics = rightSideJoiner.commitStateAndGetMetrics()
- val combinedMetrics = StateStoreMetrics.combine(Seq(leftSideMetrics, rightSideMetrics))
+ joinerManager.commit()
+ val combinedMetrics = joinerManager.metrics
if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf)) {
- val checkpointInfo = SymmetricHashJoinStateManager.mergeStateStoreCheckpointInfo(
- JoinStateStoreCkptInfo(
- leftSideJoiner.getLatestCheckpointInfo(),
- rightSideJoiner.getLatestCheckpointInfo()
- )
- )
- setStateStoreCheckpointInfo(checkpointInfo)
+ setStateStoreCheckpointInfo(joinerManager.getStateStoreCheckpointInfo)
}
// Update SQL metrics
- numUpdatedStateRows +=
- (leftSideJoiner.numUpdatedStateRows + rightSideJoiner.numUpdatedStateRows)
+ numUpdatedStateRows += joinerManager.totalNumUpdatedStateRows
numTotalStateRows += combinedMetrics.numKeys
stateMemory += combinedMetrics.memoryUsedBytes
setStoreCustomMetrics(combinedMetrics.customMetrics)
setStoreInstanceMetrics(combinedMetrics.instanceMetrics)
}
- val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide);
- setOperatorMetrics(numStateStoreInstances = stateStoreNames.length)
+ setOperatorMetrics(numStateStoreInstances = numStoresPerPartition)
}
CompletionIterator[InternalRow, Iterator[InternalRow]](
@@ -560,6 +600,11 @@ case class StreamingSymmetricHashJoinExec(
* state watermarks.
* @param oneSideStateInfo Reconstructed state info for this side
* @param partitionId A partition ID of source RDD.
+ * @param joinStateManagerStoreGenerator The state store generator responsible for getting the
+ * state store for this join side. The generator will
+ * re-use the same store for both sides when the join
+ * implementation uses virtual column families for join
+ * version 3.
*/
private class OneSideHashJoiner(
joinSide: JoinSide,
@@ -572,13 +617,14 @@ case class StreamingSymmetricHashJoinExec(
partitionId: Int,
keyToNumValuesStateStoreCkptId: Option[String],
keyWithIndexToValueStateStoreCkptId: Option[String],
- skippedNullValueCount: Option[SQLMetric]) {
+ skippedNullValueCount: Option[SQLMetric],
+ joinStateManagerStoreGenerator: JoinStateManagerStoreGenerator) {
// Filter the joined rows based on the given condition.
val preJoinFilter =
Predicate.create(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _
- private val joinStateManager = new SymmetricHashJoinStateManager(
+ private val joinStateManager = SymmetricHashJoinStateManager(
joinSide = joinSide,
inputValueAttributes = inputAttributes,
joinKeys = joinKeys,
@@ -589,7 +635,8 @@ case class StreamingSymmetricHashJoinExec(
keyToNumValuesStateStoreCkptId = keyToNumValuesStateStoreCkptId,
keyWithIndexToValueStateStoreCkptId = keyWithIndexToValueStateStoreCkptId,
stateFormatVersion = stateFormatVersion,
- skippedNullValueCount = skippedNullValueCount)
+ skippedNullValueCount = skippedNullValueCount,
+ joinStoreGenerator = joinStateManagerStoreGenerator)
private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)
@@ -763,9 +810,13 @@ case class StreamingSymmetricHashJoinExec(
}
}
- /** Commit changes to the buffer state and return the state store metrics */
- def commitStateAndGetMetrics(): StateStoreMetrics = {
+ /** Commit changes to the buffer state */
+ def commitState(): Unit = {
joinStateManager.commit()
+ }
+
+ /** Return state store metrics for state committed */
+ def getMetrics: StateStoreMetrics = {
joinStateManager.metrics
}
@@ -776,6 +827,67 @@ case class StreamingSymmetricHashJoinExec(
def numUpdatedStateRows: Long = updatedStateRowsCount
}
+ /**
+ * Case class used to manage the left and right side joiners, combining
+ * information from both sides when necessary.
+ */
+ private case class OneSideHashJoinerManager(
+ leftSideJoiner: OneSideHashJoiner, rightSideJoiner: OneSideHashJoiner) {
+
+ def removeOldState(): Iterator[KeyToValuePair] = {
+ leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
+ }
+
+ def metrics: StateStoreMetrics = {
+ if (useVirtualColumnFamilies) {
+ // Since both sides use the same state store, the metrics are already combined
+ leftSideJoiner.getMetrics
+ } else {
+ StateStoreMetrics.combine(
+ Seq(leftSideJoiner.getMetrics, rightSideJoiner.getMetrics)
+ )
+ }
+ }
+
+ def commit(): Unit = {
+ if (useVirtualColumnFamilies) {
+ // We only have one state store for both sides to commit
+ leftSideJoiner.commitState()
+ } else {
+ // We have to commit stores used on both sides
+ leftSideJoiner.commitState()
+ rightSideJoiner.commitState()
+ }
+ }
+
+ def getStateStoreCheckpointInfo: StatefulOpStateStoreCheckpointInfo = {
+ if (useVirtualColumnFamilies) {
+ // No merging needed since both fields from getLatestCheckpointInfo() should be identical
+ val storeCheckpointInfo = leftSideJoiner.getLatestCheckpointInfo().keyToNumValues
+ StatefulOpStateStoreCheckpointInfo(
+ storeCheckpointInfo.partitionId,
+ storeCheckpointInfo.batchVersion,
+ storeCheckpointInfo.stateStoreCkptId.map(Array(_)),
+ storeCheckpointInfo.baseStateStoreCkptId.map(Array(_))
+ )
+ } else {
+ // Merge checkpoint info from both sides
+ SymmetricHashJoinStateManager.mergeStateStoreCheckpointInfo(
+ JoinStateStoreCkptInfo(
+ leftSideJoiner.getLatestCheckpointInfo(),
+ rightSideJoiner.getLatestCheckpointInfo()
+ )
+ )
+ }
+ }
+
+ def totalNumUpdatedStateRows: Long = {
+ // Regardless of join implementation, combine the number of updated state rows
+ // as these are maintained outside the state store
+ leftSideJoiner.numUpdatedStateRows + rightSideJoiner.numUpdatedStateRows
+ }
+ }
+
override protected def withNewChildrenInternal(
newLeft: SparkPlan, newRight: SparkPlan): StreamingSymmetricHashJoinExec =
copy(left = newLeft, right = newRight)
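The manager above centralizes commit, metrics, and checkpoint-info handling for both join sides. The following is a minimal, self-contained sketch (simplified stand-in types, not the Spark classes) of the dispatch pattern it implements when a single physical store backs both sides versus one store per side:

```scala
// Illustrative sketch only, not part of the patch.
object JoinerManagerSketch {
  final case class Metrics(numKeys: Long, memoryUsedBytes: Long)

  final class Joiner(val name: String) {
    def commitState(): Unit = println(s"$name committed")
    def getMetrics: Metrics = Metrics(numKeys = 10, memoryUsedBytes = 1024)
  }

  final class JoinerManager(left: Joiner, right: Joiner, sharedStore: Boolean) {
    def commit(): Unit =
      if (sharedStore) left.commitState()              // one store backs both sides
      else { left.commitState(); right.commitState() } // separate stores per side

    def metrics: Metrics =
      if (sharedStore) left.getMetrics                 // already reflects both sides
      else {
        val (l, r) = (left.getMetrics, right.getMetrics)
        Metrics(l.numKeys + r.numKeys, l.memoryUsedBytes + r.memoryUsedBytes)
      }
  }

  def main(args: Array[String]): Unit = {
    val mgr = new JoinerManager(new Joiner("left"), new Joiner("right"), sharedStore = true)
    mgr.commit()
    println(mgr.metrics)
  }
}
```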
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 6e0502e186597..b2702bc019222 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -381,7 +381,7 @@ case class TransformWithStateExec(
// If timeout is based on event time, then filter late data based on watermark
val filteredIter = watermarkPredicateForDataForLateEvents match {
- case Some(predicate) =>
+ case Some(predicate) if timeMode == TimeMode.EventTime() =>
applyRemovingRowsOlderThanWatermark(iter, predicate)
case _ =>
iter
@@ -475,8 +475,9 @@ case class TransformWithStateExec(
batchId: Long,
stateSchemaVersion: Int): List[StateSchemaValidationResult] = {
val info = getStateInfo
+ val stateSchemaDir = stateSchemaDirPath()
validateAndWriteStateSchema(hadoopConf, batchId, stateSchemaVersion,
- info, session, operatorStateMetadataVersion, conf.stateStoreEncodingFormat)
+ info, stateSchemaDir, session, operatorStateMetadataVersion, conf.stateStoreEncodingFormat)
}
/** Metadata of this stateful operator and its states stores. */
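The change above guards the late-data filter on the operator's time mode. A standalone sketch of the behavior, using simplified types rather than the Spark API:

```scala
// Illustrative sketch only: the watermark-based late-data filter applies only in
// event-time mode; otherwise rows pass through unchanged.
object LateDataFilterSketch {
  sealed trait TimeMode
  case object EventTime extends TimeMode
  case object ProcessingTime extends TimeMode

  def filterLate(rows: Iterator[Long], watermarkMs: Long, timeMode: TimeMode): Iterator[Long] =
    timeMode match {
      case EventTime => rows.filter(_ >= watermarkMs) // drop rows older than the watermark
      case _         => rows                          // no event-time watermark to apply
    }

  def main(args: Array[String]): Unit = {
    assert(filterLate(Iterator(1L, 5L, 9L), 5L, EventTime).toSeq == Seq(5L, 9L))
    assert(filterLate(Iterator(1L, 5L, 9L), 5L, ProcessingTime).toSeq == Seq(1L, 5L, 9L))
  }
}
```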
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
index 021a6fa1ecbdc..7b9a478b8be19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
@@ -16,10 +16,6 @@
*/
package org.apache.spark.sql.execution.streaming
-import java.util.UUID
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
@@ -27,10 +23,8 @@ import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.StateVariableType.StateVariableType
-import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV2, StateSchemaCompatibilityChecker, StateSchemaValidationResult, StateStoreColFamilySchema, StateStoreErrors, StateStoreId, StateStoreMetadataV2}
-import org.apache.spark.sql.execution.streaming.state.StateStoreEncoding._
+import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorStateMetadata, OperatorStateMetadataV2, StateStoreErrors, StateStoreId, StateStoreMetadataV2}
import org.apache.spark.sql.streaming.{OutputMode, TimeMode}
/**
@@ -173,11 +167,10 @@ object TransformWithStateOperatorProperties extends Logging {
* `init()` with DriverStatefulProcessorHandleImpl, and get the state schema and state metadata
* on driver during physical planning phase.
*/
-trait TransformWithStateMetadataUtils extends Logging {
+trait TransformWithStateMetadataUtils extends SchemaValidationUtils with Logging {
- // This method will return the column family schemas, and check whether the fields in the
- // schema are nullable. If Avro encoding is used, we want to enforce nullability
- def getColFamilySchemas(shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema]
+ // TransformWithState operators are allowed to evolve their schemas
+ override val schemaEvolutionEnabledForOperator: Boolean = true
def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo]
@@ -201,51 +194,6 @@ trait TransformWithStateMetadataUtils extends Logging {
OperatorStateMetadataV2(operatorInfo, stateStoreInfo, operatorProperties.json)
}
- def validateAndWriteStateSchema(
- hadoopConf: Configuration,
- batchId: Long,
- stateSchemaVersion: Int,
- info: StatefulOperatorStateInfo,
- session: SparkSession,
- operatorStateMetadataVersion: Int = 2,
- stateStoreEncodingFormat: String = UnsafeRow.toString): List[StateSchemaValidationResult] = {
- assert(stateSchemaVersion >= 3)
- val usingAvro = stateStoreEncodingFormat == Avro.toString
- val newSchemas = getColFamilySchemas(usingAvro)
- val stateSchemaDir = stateSchemaDirPath(info)
- val newStateSchemaFilePath =
- new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")
- val metadataPath = new Path(info.checkpointLocation, s"${info.operatorId}")
- val metadataReader = OperatorStateMetadataReader.createReader(
- metadataPath, hadoopConf, operatorStateMetadataVersion, batchId)
- val operatorStateMetadata = try {
- metadataReader.read()
- } catch {
- // If this is the first time we are running the query, there will be no metadata
- // and this error is expected. In this case, we return None.
- case _: Exception if batchId == 0 =>
- None
- }
-
- val oldStateSchemaFilePaths: List[Path] = operatorStateMetadata match {
- case Some(metadata) =>
- metadata match {
- case v2: OperatorStateMetadataV2 =>
- v2.stateStoreInfo.head.stateSchemaFilePaths.map(new Path(_))
- case _ => List.empty
- }
- case None => List.empty
- }
- // state schema file written here, writing the new schema list we passed here
- List(StateSchemaCompatibilityChecker.
- validateAndMaybeEvolveStateSchema(info, hadoopConf,
- newSchemas.values.toList, session.sessionState, stateSchemaVersion,
- storeName = StateStoreId.DEFAULT_STORE_NAME,
- oldSchemaFilePaths = oldStateSchemaFilePaths,
- newSchemaFilePath = Some(newStateSchemaFilePath),
- schemaEvolutionEnabled = usingAvro))
- }
-
def validateNewMetadataForTWS(
oldOperatorMetadata: OperatorStateMetadata,
newOperatorMetadata: OperatorStateMetadata): Unit = {
@@ -262,14 +210,4 @@ trait TransformWithStateMetadataUtils extends Logging {
case (_, _) =>
}
}
-
- private def stateSchemaDirPath(info: StatefulOperatorStateInfo): Path = {
- val storeName = StateStoreId.DEFAULT_STORE_NAME
- val stateCheckpointPath =
- new Path(info.checkpointLocation, s"${info.operatorId.toString}")
-
- val stateSchemaPath = new Path(stateCheckpointPath, "_stateSchema")
- val storeNamePath = new Path(stateSchemaPath, storeName)
- storeNamePath
- }
}
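The refactor above moves schema validation into a shared trait, with each operator only declaring whether schema evolution is allowed. A rough sketch of the mixin pattern, with hypothetical names and bodies since the real SchemaValidationUtils is not shown in this diff:

```scala
// Hypothetical sketch; names and validation logic are simplified stand-ins.
object SchemaValidationSketch {
  trait SchemaValidationLike {
    // Each operator declares whether its state schemas may evolve across batches.
    def schemaEvolutionEnabledForOperator: Boolean

    def validateSchemas(newSchemas: List[String], oldSchemas: List[String]): Boolean =
      if (schemaEvolutionEnabledForOperator) true // evolution allowed, accept new schemas
      else newSchemas == oldSchemas               // otherwise require an exact match
  }

  object TransformWithStateLike extends SchemaValidationLike {
    override val schemaEvolutionEnabledForOperator: Boolean = true
  }

  def main(args: Array[String]): Unit =
    assert(TransformWithStateLike.validateSchemas(List("v2"), List("v1")))
}
```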
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchStream.scala
index d51f87cb1a578..dedc7c9ef7f32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchStream.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources
import org.json4s.{Formats, NoTypeHints}
import org.json4s.jackson.Serialization
-import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -94,9 +94,28 @@ class RatePerMicroBatchStream(
val (startOffset, startTimestamp) = extractOffsetAndTimestamp(start)
val (endOffset, endTimestamp) = extractOffsetAndTimestamp(end)
- assert(startOffset <= endOffset, s"startOffset($startOffset) > endOffset($endOffset)")
- assert(startTimestamp <= endTimestamp,
- s"startTimestamp($startTimestamp) > endTimestamp($endTimestamp)")
+ if (startOffset > endOffset) {
+ // This should not happen.
+ throw new SparkRuntimeException(
+ errorClass = "MALFORMED_STATE_IN_RATE_PER_MICRO_BATCH_SOURCE.INVALID_OFFSET",
+ messageParameters =
+ Map("startOffset" -> startOffset.toString, "endOffset" -> endOffset.toString))
+ }
+ if (startTimestamp > endTimestamp) {
+ // This could happen in the following scenario:
+ // 1. query starts with startingTimestamp=t1
+ // 2. query checkpoints offset for batch 0 with timestamp=t1
+ // 3. query stops, batch 0 is not committed
+ // 4. query restarts but with a new startingTimestamp=t2 (t2 > t1) and resumes batch 0
+ // Now the start offset is (offset=0, timestamp=t2) and the end offset
+ // is (offset=x, timestamp=t1)
+ throw new SparkRuntimeException(
+ errorClass = "MALFORMED_STATE_IN_RATE_PER_MICRO_BATCH_SOURCE.INVALID_TIMESTAMP",
+ messageParameters = Map(
+ "startTimestamp" -> startTimestamp.toString,
+ "endTimestamp" -> endTimestamp.toString))
+ }
+
logDebug(s"startOffset: $startOffset, startTimestamp: $startTimestamp, " +
s"endOffset: $endOffset, endTimestamp: $endTimestamp")
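The comment above describes how a restart with a new startingTimestamp can leave startTimestamp greater than endTimestamp. A small numeric illustration, using plain tuples rather than the source's offset classes:

```scala
// Illustrative sketch only, not part of the patch.
object RateSourceRestartSketch {
  def main(args: Array[String]): Unit = {
    val t1 = 1000L         // original startingTimestamp, checkpointed for batch 0
    val t2 = 2000L         // new startingTimestamp after the restart (t2 > t1)
    val start = (0L, t2)   // start offset rebuilt from the new option
    val end = (5L, t1)     // end offset recovered from the uncommitted batch-0 checkpoint
    // startOffset <= endOffset holds, but startTimestamp > endTimestamp, hence the new error.
    assert(start._1 <= end._1 && start._2 > end._2)
  }
}
```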
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 8c1f2eeb41a96..98d49596d11b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.streaming.state
import java.io._
import java.util
-import java.util.Locale
-import java.util.concurrent.atomic.LongAdder
+import java.util.{Locale, UUID}
+import java.util.concurrent.atomic.{AtomicLong, LongAdder}
import scala.collection.mutable
import scala.jdk.CollectionConverters._
@@ -219,7 +219,17 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
supportedCustomMetrics.find(_.name == name).map(_ -> value)
} + (metricStateOnCurrentVersionSizeBytes -> SizeEstimator.estimate(mapToUpdate))
- StateStoreMetrics(mapToUpdate.size(), metricsFromProvider("memoryUsedBytes"), customMetrics)
+ val instanceMetrics = Map(
+ instanceMetricSnapshotLastUpload.withNewId(
+ stateStoreId.partitionId, stateStoreId.storeName) -> lastUploadedSnapshotVersion.get()
+ )
+
+ StateStoreMetrics(
+ mapToUpdate.size(),
+ metricsFromProvider("memoryUsedBytes"),
+ customMetrics,
+ instanceMetrics
+ )
}
override def getStateStoreCheckpointInfo(): StateStoreCheckpointInfo = {
@@ -386,6 +396,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
Nil
}
+ override def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] =
+ Seq(instanceMetricSnapshotLastUpload)
+
private def toMessageWithContext: MessageWithContext = {
log"HDFSStateStoreProvider[id = (op=${MDC(LogKeys.OP_ID, stateStoreId.operatorId)}," +
log"part=${MDC(LogKeys.PARTITION_ID, stateStoreId.partitionId)})," +
@@ -419,6 +432,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
private val loadedMapCacheHitCount: LongAdder = new LongAdder
private val loadedMapCacheMissCount: LongAdder = new LongAdder
+ // This is updated when the maintenance task writes the snapshot file and is read by the
+ // task thread. A value of -1 means no version has ever been uploaded.
+ private val lastUploadedSnapshotVersion: AtomicLong = new AtomicLong(-1L)
+
private lazy val metricStateOnCurrentVersionSizeBytes: StateStoreCustomSizeMetric =
StateStoreCustomSizeMetric("stateOnCurrentVersionSizeBytes",
"estimated size of state only on current version")
@@ -431,6 +448,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
StateStoreCustomSumMetric("loadedMapCacheMissCount",
"count of cache miss on states cache in provider")
+ private lazy val instanceMetricSnapshotLastUpload: StateStoreInstanceMetric =
+ StateStoreSnapshotLastUploadInstanceMetric()
+
private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
private def commitUpdates(
@@ -531,6 +551,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
val snapshotCurrentVersionMap = readSnapshotFile(version)
if (snapshotCurrentVersionMap.isDefined) {
synchronized { putStateIntoStateCacheMap(version, snapshotCurrentVersionMap.get) }
+
+ // Report the loaded snapshot's version to the coordinator
+ reportSnapshotUploadToCoordinator(version)
+
return snapshotCurrentVersionMap.get
}
@@ -560,6 +584,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
}
synchronized { putStateIntoStateCacheMap(version, resultMap) }
+
+ // Report the last available snapshot's version to the coordinator
+ reportSnapshotUploadToCoordinator(lastAvailableVersion)
+
resultMap
}
@@ -677,6 +705,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
logInfo(log"Written snapshot file for version ${MDC(LogKeys.FILE_VERSION, version)} of " +
log"${MDC(LogKeys.STATE_STORE_PROVIDER, this)} at ${MDC(LogKeys.FILE_NAME, targetFile)} " +
log"for ${MDC(LogKeys.OP_TYPE, opType)}")
+ // Compare and update with the version that was just uploaded.
+ lastUploadedSnapshotVersion.updateAndGet(v => Math.max(version, v))
+ // Report the snapshot upload event to the coordinator
+ reportSnapshotUploadToCoordinator(version)
}
/**
@@ -1021,6 +1053,18 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
keySchema, valueSchema)
}
+
+ /** Reports to the coordinator the store's latest snapshot version */
+ private def reportSnapshotUploadToCoordinator(version: Long): Unit = {
+ if (storeConf.reportSnapshotUploadLag) {
+ // Attach the query run ID and current timestamp to the RPC message
+ val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf))
+ val currentTimestamp = System.currentTimeMillis()
+ StateStoreProvider.coordinatorRef.foreach(
+ _.snapshotUploaded(StateStoreProviderId(stateStoreId, runId), version, currentTimestamp)
+ )
+ }
+ }
}
/** [[StateStoreChangeDataReader]] implementation for [[HDFSBackedStateStoreProvider]] */
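The provider above tracks the last uploaded snapshot version with an AtomicLong and only ever moves it forward. A minimal sketch of that monotonic-update pattern:

```scala
// Illustrative sketch only: updateAndGet with Math.max ensures the recorded snapshot
// version never moves backwards, even if an older snapshot finishes uploading later.
import java.util.concurrent.atomic.AtomicLong

object LastUploadedVersionSketch {
  private val lastUploadedSnapshotVersion = new AtomicLong(-1L) // -1 = nothing uploaded yet

  def recordUpload(version: Long): Long =
    lastUploadedSnapshotVersion.updateAndGet(v => Math.max(version, v))

  def main(args: Array[String]): Unit = {
    recordUpload(12L)                                // -> 12
    recordUpload(9L)                                 // out-of-order completion, stays 12
    assert(lastUploadedSnapshotVersion.get() == 12L)
  }
}
```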
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
index 13fbeda18689d..befa3fb817224 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
@@ -24,7 +24,8 @@ import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path, PathFilter}
-import org.json4s.{Formats, NoTypeHints}
+import org.json4s.{Formats, JBool, JObject, NoTypeHints}
+import org.json4s.jackson.JsonMethods.{compact, render}
import org.json4s.jackson.Serialization
import org.apache.spark.internal.{Logging, LogKeys, MDC}
@@ -331,9 +332,12 @@ class OperatorStateMetadataV2Reader(
if (!fm.exists(offsetLog)) {
return Array.empty
}
+ // Offset log file names are numeric, so skip any files whose names don't
+ // conform to this format
fm.list(offsetLog)
.filter(f => !f.getPath.getName.startsWith(".")) // ignore hidden files
- .map(_.getPath.getName.toLong).sorted
+ .flatMap(f => scala.util.Try(f.getPath.getName.toLong).toOption)
+ .sorted
}
// List the available batches in the operator metadata directory
@@ -341,7 +345,11 @@ class OperatorStateMetadataV2Reader(
if (!fm.exists(metadataDirPath)) {
return Array.empty
}
- fm.list(metadataDirPath).map(_.getPath.getName.toLong).sorted
+
+ // filter out non-numeric file names (as OperatorStateMetadataV2 file names are numeric)
+ fm.list(metadataDirPath)
+ .flatMap(f => scala.util.Try(f.getPath.getName.toLong).toOption)
+ .sorted
}
override def read(): Option[OperatorStateMetadata] = {
@@ -406,6 +414,8 @@ class OperatorStateMetadataV2FileManager(
if (thresholdBatchId != 0) {
val earliestBatchIdKept = deleteMetadataFiles(thresholdBatchId)
// we need to delete everything from 0 to (earliestBatchIdKept - 1), inclusive
+ // TODO: [SPARK-50845]: Currently, deleteSchemaFiles is a no-op since earliestBatchIdKept
+ // is always 0, and the earliest schema file to 'keep' is -1.
deleteSchemaFiles(earliestBatchIdKept - 1)
}
}
@@ -417,11 +427,19 @@ class OperatorStateMetadataV2FileManager(
commitLog.listBatchesOnDisk.headOption.getOrElse(0L)
}
+ // TODO: [SPARK-50845]: Currently, deleteSchemaFiles is a no-op since thresholdBatchId
+ // is always -1
private def deleteSchemaFiles(thresholdBatchId: Long): Unit = {
+ if (thresholdBatchId <= 0) {
+ return
+ }
+ // StateSchemaV3 filenames are of the format {batchId}_{UUID},
+ // so skip any files that do not follow this format
val schemaFiles = fm.list(stateSchemaPath).sorted.map(_.getPath)
val filesBeforeThreshold = schemaFiles.filter { path =>
- val batchIdInPath = path.getName.split("_").head.toLong
- batchIdInPath <= thresholdBatchId
+ scala.util.Try(path.getName.split("_").head.toLong)
+ .toOption
+ .exists(_ <= thresholdBatchId)
}
filesBeforeThreshold.foreach { path =>
fm.delete(path)
@@ -459,8 +477,21 @@ class OperatorStateMetadataV2FileManager(
}
}
- // TODO: Implement state schema file purging logic once we have
- // enabled full-rewrite.
+ // TODO: [SPARK-50845]: Return earliest schema file we need after implementing
+ // full-rewrite
0
}
}
+
+/**
+ * Case class used to store additional properties for the join operation.
+ * It is currently only used in unit tests, which verify that the properties in
+ * the corresponding OperatorStateMetadataV2 result are non-empty.
+ */
+case class StreamingJoinOperatorProperties(useVirtualColumnFamilies: Boolean) {
+ def json: String = {
+ val json =
+ JObject("useVirtualColumnFamilies" -> JBool(useVirtualColumnFamilies))
+ compact(render(json))
+ }
+}
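The readers above now skip file names that are not numeric batch ids (or do not match the {batchId}_{UUID} schema-file pattern) instead of failing on toLong. A sketch of that defensive parsing:

```scala
// Illustrative sketch only, not part of the patch.
import scala.util.Try

object BatchFileParsingSketch {
  def main(args: Array[String]): Unit = {
    val offsetFiles = Seq("0", "1", "2", "metadata", "compact.tmp")
    val batchIds = offsetFiles.flatMap(name => Try(name.toLong).toOption).sorted
    assert(batchIds == Seq(0L, 1L, 2L))

    val thresholdBatchId = 2L
    val schemaFiles = Seq("0_a1b2", "3_c4d5", "README")
    val eligibleForDeletion = schemaFiles.filter { name =>
      Try(name.split("_").head.toLong).toOption.exists(_ <= thresholdBatchId)
    }
    assert(eligibleForDeletion == Seq("0_a1b2"))
  }
}
```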
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 820322d1e0ee1..6b3bec2077037 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -60,11 +60,11 @@ case object StoreTaskCompletionListener extends RocksDBOpType("store_task_comple
*
* @note This class is not thread-safe, so use it only from one thread.
* @see [[RocksDBFileManager]] to see how the files are laid out in local disk and DFS.
- * @param dfsRootDir Remote directory where checkpoints are going to be written
* @param conf Configuration for RocksDB
+ * @param stateStoreId StateStoreId for the state store
* @param localRootDir Root directory in local disk that is used for the working and checkpointing dirs
* @param hadoopConf Hadoop configuration for talking to the remote file system
- * @param loggingId Id that will be prepended in logs for isolating concurrent RocksDBs
+ * @param eventForwarder The RocksDBEventForwarder object for reporting events to the coordinator
*/
class RocksDB(
dfsRootDir: String,
@@ -73,7 +73,9 @@ class RocksDB(
hadoopConf: Configuration = new Configuration,
loggingId: String = "",
useColumnFamilies: Boolean = false,
- enableStateStoreCheckpointIds: Boolean = false) extends Logging {
+ enableStateStoreCheckpointIds: Boolean = false,
+ partitionId: Int = 0,
+ eventForwarder: Option[RocksDBEventForwarder] = None) extends Logging {
import RocksDB._
@@ -135,7 +137,23 @@ class RocksDB(
private val nativeStats = rocksDbOptions.statistics()
private val workingDir = createTempDir("workingDir")
- private val fileManager = new RocksDBFileManager(dfsRootDir, createTempDir("fileManager"),
+
+ protected def createFileManager(
+ dfsRootDir: String,
+ localTempDir: File,
+ hadoopConf: Configuration,
+ codecName: String,
+ loggingId: String): RocksDBFileManager = {
+ new RocksDBFileManager(
+ dfsRootDir,
+ localTempDir,
+ hadoopConf,
+ codecName,
+ loggingId = loggingId
+ )
+ }
+
+ private[spark] val fileManager = createFileManager(dfsRootDir, createTempDir("fileManager"),
hadoopConf, conf.compressionCodec, loggingId = loggingId)
private val byteArrayPair = new ByteArrayPair()
private val commitLatencyMs = new mutable.HashMap[String, Long]()
@@ -267,7 +285,7 @@ class RocksDB(
* @return - true if the column family exists, false otherwise
*/
def checkColFamilyExists(colFamilyName: String): Boolean = {
- db != null && colFamilyNameToInfoMap.containsKey(colFamilyName)
+ colFamilyNameToInfoMap.containsKey(colFamilyName)
}
// This method sets the internal column family metadata to
@@ -387,6 +405,9 @@ class RocksDB(
// Initialize maxVersion upon successful load from DFS
fileManager.setMaxSeenVersion(version)
+ // Report this snapshot version to the coordinator
+ reportSnapshotUploadToCoordinator(latestSnapshotVersion)
+
openLocalRocksDB(metadata)
if (loadedVersion != version) {
@@ -464,6 +485,9 @@ class RocksDB(
// Initialize maxVersion upon successful load from DFS
fileManager.setMaxSeenVersion(version)
+ // Report this snapshot version to the coordinator
+ reportSnapshotUploadToCoordinator(latestSnapshotVersion)
+
openLocalRocksDB(metadata)
if (loadedVersion != version) {
@@ -532,11 +556,13 @@ class RocksDB(
maxColumnFamilyId.set(maxId)
}
+ openDB()
+ // Call this after opening the DB to ensure that a forced snapshot is not triggered
+ // unnecessarily.
if (useColumnFamilies) {
createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, isInternal = false)
}
- openDB()
val (numKeys, numInternalKeys) = {
if (!conf.trackTotalNumberOfRows) {
// we don't track the total number of rows - discard the number being track
@@ -599,6 +625,8 @@ class RocksDB(
loadedVersion = -1 // invalidate loaded data
throw t
}
+ // Report this snapshot version to the coordinator
+ reportSnapshotUploadToCoordinator(snapshotVersion)
this
}
@@ -671,16 +699,15 @@ class RocksDB(
if (useColumnFamilies) {
changelogReader.foreach { case (recordType, key, value) =>
- val (keyWithoutPrefix, cfName) = decodeStateRowWithPrefix(key)
recordType match {
case RecordType.PUT_RECORD =>
- put(keyWithoutPrefix, value, cfName)
+ put(key, value, includesPrefix = true)
case RecordType.DELETE_RECORD =>
- remove(keyWithoutPrefix, cfName)
+ remove(key, includesPrefix = true)
case RecordType.MERGE_RECORD =>
- merge(keyWithoutPrefix, value, cfName)
+ merge(key, value, includesPrefix = true)
}
}
} else {
@@ -801,8 +828,9 @@ class RocksDB(
def put(
key: Array[Byte],
value: Array[Byte],
- cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
- val keyWithPrefix = if (useColumnFamilies) {
+ cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
+ includesPrefix: Boolean = false): Unit = {
+ val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
encodeStateRowWithPrefix(key, cfName)
} else {
key
@@ -827,8 +855,9 @@ class RocksDB(
def merge(
key: Array[Byte],
value: Array[Byte],
- cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
- val keyWithPrefix = if (useColumnFamilies) {
+ cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
+ includesPrefix: Boolean = false): Unit = {
+ val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
encodeStateRowWithPrefix(key, cfName)
} else {
key
@@ -843,8 +872,11 @@ class RocksDB(
* Remove the key if present.
* @note This update is not committed to disk until commit() is called.
*/
- def remove(key: Array[Byte], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
- val keyWithPrefix = if (useColumnFamilies) {
+ def remove(
+ key: Array[Byte],
+ cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
+ includesPrefix: Boolean = false): Unit = {
+ val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
encodeStateRowWithPrefix(key, cfName)
} else {
key
@@ -987,7 +1019,7 @@ class RocksDB(
* - Create a RocksDB checkpoint in a new local dir
* - Sync the checkpoint dir files to DFS
*/
- def commit(): Long = {
+ def commit(): (Long, StateStoreCheckpointInfo) = {
val newVersion = loadedVersion + 1
try {
logInfo(log"Flushing updates for ${MDC(LogKeys.VERSION_NUM, newVersion)}")
@@ -1056,7 +1088,7 @@ class RocksDB(
recordedMetrics = Some(metrics)
logInfo(log"Committed ${MDC(LogKeys.VERSION_NUM, newVersion)}, " +
log"stats = ${MDC(LogKeys.METRICS_JSON, recordedMetrics.get.json)}")
- loadedVersion
+ (loadedVersion, getLatestCheckpointInfo)
} catch {
case t: Throwable =>
loadedVersion = -1 // invalidate loaded version
@@ -1107,6 +1139,10 @@ class RocksDB(
val (dfsFileSuffix, immutableFileMapping) = rocksDBFileMapping.createSnapshotFileMapping(
fileManager, checkpointDir, version)
+ logInfo(log"RocksDB file mapping after creating snapshot file mapping for version " +
+ log"${MDC(LogKeys.VERSION_NUM, version)}:\n" +
+ log"${MDC(LogKeys.ROCKS_DB_FILE_MAPPING, rocksDBFileMapping)}")
+
val newSnapshot = Some(RocksDBSnapshot(
checkpointDir,
version,
@@ -1209,6 +1245,8 @@ class RocksDB(
}
silentDeleteRecursively(localRootDir, "closing RocksDB")
+ // Clear internal maps to reset the state
+ clearColFamilyMaps()
} catch {
case e: Exception =>
logWarning("Error closing RocksDB", e)
@@ -1224,11 +1262,11 @@ class RocksDB(
def getWriteBufferManagerAndCache(): (WriteBufferManager, Cache) = (writeBufferManager, lruCache)
/**
- * Called by RocksDBStateStoreProvider to retrieve the checkpoint information to be
+ * Called by commit() to retrieve the checkpoint information to be
* passed back to the stateful operator. It will return the information for the latest
* state store checkpointing.
*/
- def getLatestCheckpointInfo(partitionId: Int): StateStoreCheckpointInfo = {
+ private def getLatestCheckpointInfo: StateStoreCheckpointInfo = {
StateStoreCheckpointInfo(
partitionId,
loadedVersion,
@@ -1459,14 +1497,16 @@ class RocksDB(
// This is relatively aggressive because even if the upload succeeds,
// it is not necessarily the one written to the commit log. But we can always load lineage
// from the commit log, so it is fine.
- lineageManager.resetLineage(lineageManager.getLineageForCurrVersion()
- .filter(i => i.version >= snapshot.version))
+ lineageManager.truncateFromVersion(snapshot.version)
logInfo(log"${MDC(LogKeys.LOG_ID, loggingId)}: " +
log"Upload snapshot of version ${MDC(LogKeys.VERSION_NUM, snapshot.version)}, " +
log"with uniqueId: ${MDC(LogKeys.UUID, snapshot.uniqueId)} " +
log"time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms. " +
log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
- lastUploadedSnapshotVersion.set(snapshot.version)
+ // Compare and update with the version that was just uploaded.
+ lastUploadedSnapshotVersion.updateAndGet(v => Math.max(snapshot.version, v))
+ // Report snapshot upload event to the coordinator.
+ reportSnapshotUploadToCoordinator(snapshot.version)
} finally {
snapshot.close()
}
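After a successful snapshot upload, the lineage is truncated so that only entries at or above the uploaded version remain. A small sketch of that truncation:

```scala
// Illustrative sketch only, not part of the patch.
object LineageTruncationSketch {
  final case class LineageItem(version: Long, checkpointUniqueId: String)

  def truncateFromVersion(lineage: Array[LineageItem], versionToKeep: Long): Array[LineageItem] =
    lineage.filter(_.version >= versionToKeep)

  def main(args: Array[String]): Unit = {
    val lineage = Array(LineageItem(3, "a"), LineageItem(4, "b"), LineageItem(5, "c"))
    assert(truncateFromVersion(lineage, 4).map(_.version).toSeq == Seq(4L, 5L))
  }
}
```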
@@ -1474,17 +1514,28 @@ class RocksDB(
fileManagerMetrics
}
+ /** Reports to the coordinator, via the event forwarder, that a snapshot finished uploading */
+ private def reportSnapshotUploadToCoordinator(version: Long): Unit = {
+ if (conf.reportSnapshotUploadLag) {
+ // Note that we still report snapshot versions even when changelog checkpointing is disabled.
+ // The coordinator needs a way to determine whether upload messages are disabled or not,
+ // which would be different between RocksDB and HDFS stores due to changelog checkpointing.
+ eventForwarder.foreach(_.reportSnapshotUploaded(version))
+ }
+ }
+
/** Create a native RocksDB logger that forwards native logs to log4j with correct log levels. */
private def createLogger(): Logger = {
val dbLogger = new Logger(rocksDbOptions.infoLogLevel()) {
override def log(infoLogLevel: InfoLogLevel, logMsg: String) = {
// Map DB log level to log4j levels
// Warn is mapped to info because RocksDB warn is too verbose
+ // Info is mapped to debug because RocksDB info is too verbose
// (e.g. dumps non-warning stuff like stats)
val loggingFunc: ( => LogEntry) => Unit = infoLogLevel match {
case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
- case InfoLogLevel.WARN_LEVEL | InfoLogLevel.INFO_LEVEL => logInfo(_)
- case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
+ case InfoLogLevel.WARN_LEVEL => logInfo(_)
+ case InfoLogLevel.INFO_LEVEL | InfoLogLevel.DEBUG_LEVEL => logDebug(_)
case _ => logTrace(_)
}
loggingFunc(log"[NativeRocksDB-${MDC(LogKeys.ROCKS_DB_LOG_LEVEL, infoLogLevel.getValue)}]" +
@@ -1573,6 +1624,16 @@ class RocksDBFileMapping {
// from reusing SST files which have not been yet persisted to DFS,
val snapshotsPendingUpload: Set[RocksDBVersionSnapshotInfo] = ConcurrentHashMap.newKeySet()
+ /**
+ * Clear everything stored in the file mapping.
+ */
+ def clear(): Unit = {
+ localFileMappings.clear()
+ snapshotsPendingUpload.clear()
+ }
+
+ override def toString: String = localFileMappings.toString()
+
/**
* Get the mapped DFS file for the given local file for a DFS load operation.
* If the currently mapped DFS file was mapped in the same or newer version as the version we
@@ -1589,14 +1650,21 @@ class RocksDBFileMapping {
fileManager: RocksDBFileManager,
localFileName: String,
versionToLoad: Long): Option[RocksDBImmutableFile] = {
- getDfsFileWithVersionCheck(fileManager, localFileName, _ >= versionToLoad)
+ getDfsFileWithIncompatibilityCheck(
+ fileManager,
+ localFileName,
+ // We can't reuse the current local file since it was added in the same or newer version
+ // as the version we want to load
+ (fileVersion, _) => fileVersion >= versionToLoad
+ )
}
/**
* Get the mapped DFS file for the given local file for a DFS save (i.e. checkpoint) operation.
* If the currently mapped DFS file was mapped in the same or newer version as the version we
- * want to save (or was generated in a version which has not been uploaded to DFS yet),
- * the mapped DFS file is ignored. In this scenario, the local mapping to this DFS file
+ * want to save (or was generated in a version which has not been uploaded to DFS yet)
+ * or the mapped DFS file isn't the same size as the local file,
+ * then the mapped DFS file is ignored. In this scenario, the local mapping to this DFS file
* will be cleared, and function will return None.
*
* @note If the file was added in current version (i.e. versionToSave - 1), we can reuse it.
@@ -1607,19 +1675,26 @@ class RocksDBFileMapping {
*/
private def getDfsFileForSave(
fileManager: RocksDBFileManager,
- localFileName: String,
+ localFile: File,
versionToSave: Long): Option[RocksDBImmutableFile] = {
- getDfsFileWithVersionCheck(fileManager, localFileName, _ >= versionToSave)
+ getDfsFileWithIncompatibilityCheck(
+ fileManager,
+ localFile.getName,
+ (dfsFileVersion, dfsFile) =>
+ // The DFS file cannot be reused for the file we want to save if
+ // it was added in the same or a higher version, or if the file sizes differ
+ dfsFileVersion >= versionToSave || dfsFile.sizeBytes != localFile.length()
+ )
}
- private def getDfsFileWithVersionCheck(
+ private def getDfsFileWithIncompatibilityCheck(
fileManager: RocksDBFileManager,
localFileName: String,
- isIncompatibleVersion: Long => Boolean): Option[RocksDBImmutableFile] = {
+ isIncompatible: (Long, RocksDBImmutableFile) => Boolean): Option[RocksDBImmutableFile] = {
localFileMappings.get(localFileName).map { case (dfsFileMappedVersion, dfsFile) =>
val dfsFileSuffix = fileManager.dfsFileSuffix(dfsFile)
val versionSnapshotInfo = RocksDBVersionSnapshotInfo(dfsFileMappedVersion, dfsFileSuffix)
- if (isIncompatibleVersion(dfsFileMappedVersion) ||
+ if (isIncompatible(dfsFileMappedVersion, dfsFile) ||
snapshotsPendingUpload.contains(versionSnapshotInfo)) {
// the mapped dfs file cannot be used, delete from mapping
remove(localFileName)
@@ -1661,7 +1736,7 @@ class RocksDBFileMapping {
val dfsFilesSuffix = UUID.randomUUID().toString
val snapshotFileMapping = localImmutableFiles.map { f =>
val localFileName = f.getName
- val existingDfsFile = getDfsFileForSave(fileManager, localFileName, version)
+ val existingDfsFile = getDfsFileForSave(fileManager, f, version)
val dfsFile = existingDfsFile.getOrElse {
val newDfsFileName = fileManager.newDFSFileName(localFileName, dfsFilesSuffix)
val newDfsFile = RocksDBImmutableFile(localFileName, newDfsFileName, sizeBytes = f.length())
@@ -1715,7 +1790,8 @@ case class RocksDBConf(
highPriorityPoolRatio: Double,
compressionCodec: String,
allowFAllocate: Boolean,
- compression: String)
+ compression: String,
+ reportSnapshotUploadLag: Boolean)
object RocksDBConf {
/** Common prefix of all confs in SQLConf that affects RocksDB */
@@ -1898,7 +1974,8 @@ object RocksDBConf {
getRatioConf(HIGH_PRIORITY_POOL_RATIO_CONF),
storeConf.compressionCodec,
getBooleanConf(ALLOW_FALLOCATE_CONF),
- getStringConf(COMPRESSION_CONF))
+ getStringConf(COMPRESSION_CONF),
+ storeConf.reportSnapshotUploadLag)
}
def apply(): RocksDBConf = apply(new StateStoreConf())
@@ -1974,27 +2051,33 @@ case class AcquiredThreadInfo(
private[sql] class RocksDBLineageManager {
@volatile private var lineage: Array[LineageItem] = Array.empty
- override def toString: String = lineage.map {
- case LineageItem(version, uuid) => s"$version: $uuid"
- }.mkString(" ")
+ override def toString: String = synchronized {
+ lineage.map {
+ case LineageItem(version, uuid) => s"$version: $uuid"
+ }.mkString(" ")
+ }
- def appendLineageItem(item: LineageItem): Unit = {
+ def appendLineageItem(item: LineageItem): Unit = synchronized {
lineage = lineage :+ item
}
- def resetLineage(newLineage: Array[LineageItem]): Unit = {
+ def truncateFromVersion(versionToKeep: Long): Unit = synchronized {
+ resetLineage(getLineageForCurrVersion().filter(i => i.version >= versionToKeep))
+ }
+
+ def resetLineage(newLineage: Array[LineageItem]): Unit = synchronized {
lineage = newLineage
}
- def getLineageForCurrVersion(): Array[LineageItem] = {
+ def getLineageForCurrVersion(): Array[LineageItem] = synchronized {
lineage.clone()
}
- def contains(item: LineageItem): Boolean = {
+ def contains(item: LineageItem): Boolean = synchronized {
lineage.contains(item)
}
- def clear(): Unit = {
+ def clear(): Unit = synchronized {
lineage = Array.empty
}
}
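The includesPrefix flag added to put/merge/remove lets changelog replay write back keys that already carry the column-family prefix without prefixing them again. A simplified, self-contained sketch of the idea (the string prefix here is illustrative; the real code prepends a binary column-family id to the key bytes):

```scala
// Illustrative sketch only, not part of the patch.
object PrefixedPutSketch {
  import scala.collection.mutable

  private def encodeWithPrefix(key: String, cfName: String): String = s"$cfName/$key"

  def put(
      store: mutable.Map[String, String],
      key: String,
      value: String,
      cfName: String = "default",
      includesPrefix: Boolean = false): Unit = {
    val storedKey = if (!includesPrefix) encodeWithPrefix(key, cfName) else key
    store(storedKey) = value
  }

  def main(args: Array[String]): Unit = {
    val store = mutable.Map.empty[String, String]
    put(store, "k1", "v1", cfName = "left")                                  // normal write path
    put(store, encodeWithPrefix("k2", "right"), "v2", includesPrefix = true) // changelog replay
    assert(store.keySet == Set("left/k1", "right/k2"))
  }
}
```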
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index bb1198dfccafc..562a57aafbd41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.execution.streaming.state
import java.io.{File, FileInputStream, InputStream}
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Files
-import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{ZipEntry, ZipOutputStream}
import scala.collection.{mutable, Map}
+import scala.math._
import com.fasterxml.jackson.annotation.JsonInclude.Include
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
@@ -32,11 +32,11 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}
import org.apache.commons.io.{FilenameUtils, IOUtils}
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path, PathFilter}
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
import org.json4s.{Formats, NoTypeHints}
import org.json4s.jackson.Serialization
-import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException}
import org.apache.spark.internal.{Logging, LogKeys, MDC, MessageWithContext}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -133,8 +133,12 @@ class RocksDBFileManager(
import RocksDBImmutableFile._
+ protected def getFileSystem(myDfsRootDir: String, myHadoopConf: Configuration) : FileSystem = {
+ new Path(myDfsRootDir).getFileSystem(myHadoopConf)
+ }
+
private lazy val fm = CheckpointFileManager.create(new Path(dfsRootDir), hadoopConf)
- private val fs = new Path(dfsRootDir).getFileSystem(hadoopConf)
+ private val fs = getFileSystem(dfsRootDir, hadoopConf)
private val onlyZipFiles = new PathFilter {
override def accept(path: Path): Boolean = path.toString.endsWith(".zip")
}
@@ -178,40 +182,50 @@ class RocksDBFileManager(
checkpointUniqueId: Option[String] = None,
stateStoreCheckpointIdLineage: Option[Array[LineageItem]] = None
): StateStoreChangelogWriter = {
- val changelogFile = dfsChangelogFile(version, checkpointUniqueId)
- if (!rootDirChecked) {
- val rootDir = new Path(dfsRootDir)
- if (!fm.exists(rootDir)) fm.mkdirs(rootDir)
- rootDirChecked = true
- }
+ try {
+ val changelogFile = dfsChangelogFile(version, checkpointUniqueId)
+ if (!rootDirChecked) {
+ val rootDir = new Path(dfsRootDir)
+ if (!fm.exists(rootDir)) fm.mkdirs(rootDir)
+ rootDirChecked = true
+ }
- val enableStateStoreCheckpointIds = checkpointUniqueId.isDefined
- val changelogVersion = getChangelogWriterVersion(
- useColumnFamilies, enableStateStoreCheckpointIds)
-
- val changelogWriter = changelogVersion match {
- case 1 =>
- new StateStoreChangelogWriterV1(fm, changelogFile, codec)
- case 2 =>
- new StateStoreChangelogWriterV2(fm, changelogFile, codec)
- case 3 =>
- assert(enableStateStoreCheckpointIds && stateStoreCheckpointIdLineage.isDefined,
- "StateStoreChangelogWriterV3 should only be initialized when " +
- "state store checkpoint unique id is enabled")
- new StateStoreChangelogWriterV3(fm, changelogFile, codec, stateStoreCheckpointIdLineage.get)
- case 4 =>
- assert(enableStateStoreCheckpointIds && stateStoreCheckpointIdLineage.isDefined,
- "StateStoreChangelogWriterV4 should only be initialized when " +
- "state store checkpoint unique id is enabled")
- new StateStoreChangelogWriterV4(fm, changelogFile, codec, stateStoreCheckpointIdLineage.get)
- case _ =>
- throw QueryExecutionErrors.invalidChangeLogWriterVersion(changelogVersion)
- }
+ val enableStateStoreCheckpointIds = checkpointUniqueId.isDefined
+ val changelogVersion = getChangelogWriterVersion(
+ useColumnFamilies, enableStateStoreCheckpointIds)
+
+ val changelogWriter = changelogVersion match {
+ case 1 =>
+ new StateStoreChangelogWriterV1(fm, changelogFile, codec)
+ case 2 =>
+ new StateStoreChangelogWriterV2(fm, changelogFile, codec)
+ case 3 =>
+ assert(enableStateStoreCheckpointIds && stateStoreCheckpointIdLineage.isDefined,
+ "StateStoreChangelogWriterV3 should only be initialized when " +
+ "state store checkpoint unique id is enabled")
+ new StateStoreChangelogWriterV3(fm, changelogFile, codec,
+ stateStoreCheckpointIdLineage.get)
+ case 4 =>
+ assert(enableStateStoreCheckpointIds && stateStoreCheckpointIdLineage.isDefined,
+ "StateStoreChangelogWriterV4 should only be initialized when " +
+ "state store checkpoint unique id is enabled")
+ new StateStoreChangelogWriterV4(fm, changelogFile, codec,
+ stateStoreCheckpointIdLineage.get)
+ case _ =>
+ throw QueryExecutionErrors.invalidChangeLogWriterVersion(changelogVersion)
+ }
- logInfo(log"Loaded change log reader version " +
- log"${MDC(LogKeys.FILE_VERSION, changelogWriter.version)}")
+ logInfo(log"Loaded change log reader version " +
+ log"${MDC(LogKeys.FILE_VERSION, changelogWriter.version)}")
- changelogWriter
+ changelogWriter
+ } catch {
+ case e: SparkException
+ if Option(e.getCondition).exists(_.contains("CANNOT_LOAD_STATE_STORE")) =>
+ throw e
+ case e: Throwable =>
+ throw StateStoreErrors.failedToGetChangelogWriter(version, e)
+ }
}
// Get the changelog file at version
@@ -281,7 +295,7 @@ class RocksDBFileManager(
colFamilyIdMapping, colFamilyTypeMapping, maxColumnFamilyId)
val metadataFile = localMetadataFile(checkpointDir)
metadata.writeToFile(metadataFile)
- logInfo(log"Written metadata for version ${MDC(LogKeys.VERSION_NUM, version)}:\n" +
+ logDebug(log"Written metadata for version ${MDC(LogKeys.VERSION_NUM, version)}:\n" +
log"${MDC(LogKeys.METADATA_JSON, metadata.prettyJson)}")
if (version <= 1 && numKeys <= 0) {
@@ -322,6 +336,8 @@ class RocksDBFileManager(
val metadata = if (version == 0) {
if (localDir.exists) Utils.deleteRecursively(localDir)
localDir.mkdirs()
+ // Since we cleared the local dir, we should also clear the local file mapping
+ rocksDBFileMapping.clear()
RocksDBCheckpointMetadata(Seq.empty, 0)
} else {
// Delete all non-immutable files in local dir, and unzip new ones from DFS commit file
@@ -331,7 +347,7 @@ class RocksDBFileManager(
// Copy the necessary immutable files
val metadataFile = localMetadataFile(localDir)
val metadata = RocksDBCheckpointMetadata.readFromFile(metadataFile)
- logInfo(log"Read metadata for version ${MDC(LogKeys.VERSION_NUM, version)}:\n" +
+ logDebug(log"Read metadata for version ${MDC(LogKeys.VERSION_NUM, version)}:\n" +
log"${MDC(LogKeys.METADATA_JSON, metadata.prettyJson)}")
loadImmutableFilesFromDfs(metadata.immutableFiles, localDir, rocksDBFileMapping, version)
versionToRocksDBFiles.put((version, checkpointUniqueId), metadata.immutableFiles)
@@ -340,6 +356,10 @@ class RocksDBFileManager(
}
logFilesInDir(localDir, log"Loaded checkpoint files " +
log"for version ${MDC(LogKeys.VERSION_NUM, version)}")
+ logInfo(log"RocksDB file mapping after loading checkpoint version " +
+ log"${MDC(LogKeys.VERSION_NUM, version)} from DFS:\n" +
+ log"${MDC(LogKeys.ROCKS_DB_FILE_MAPPING, rocksDBFileMapping)}")
+
metadata
}
@@ -507,7 +527,9 @@ class RocksDBFileManager(
logInfo(log"Estimated maximum version is " +
log"${MDC(LogKeys.MAX_SEEN_VERSION, maxSeenVersion.get)}" +
log" and minimum version is ${MDC(LogKeys.MIN_SEEN_VERSION, minSeenVersion)}")
- val versionsToDelete = maxSeenVersion.get - minSeenVersion + 1 - numVersionsToRetain
+ // If the number of versions to delete is negative, none of the versions
+ // are eligible for deletion, so clamp the value to 0
+ val versionsToDelete = max(maxSeenVersion.get - minSeenVersion + 1 - numVersionsToRetain, 0)
if (versionsToDelete < minVersionsToDelete) {
logInfo(log"Skipping deleting files." +
log" Need at least ${MDC(LogKeys.MIN_VERSIONS_TO_DELETE, minVersionsToDelete)}" +
@@ -833,7 +855,7 @@ class RocksDBFileManager(
totalBytes += bytes
}
zout.close() // so that any error in closing also cancels the output stream
- logInfo(log"Zipped ${MDC(LogKeys.NUM_BYTES, totalBytes)} bytes (before compression) to " +
+ logDebug(log"Zipped ${MDC(LogKeys.NUM_BYTES, totalBytes)} bytes (before compression) to " +
log"${MDC(LogKeys.FILE_NAME, filesStr)}")
// The other fields saveCheckpointMetrics should have been filled
saveCheckpointMetrics =
@@ -856,16 +878,10 @@ class RocksDBFileManager(
lazy val files = Option(Utils.recursiveList(dir)).getOrElse(Array.empty).map { f =>
s"${f.getAbsolutePath} - ${f.length()} bytes"
}
- logInfo(msg + log" - ${MDC(LogKeys.NUM_FILES, files.length)} files\n\t" +
+ logDebug(msg + log" - ${MDC(LogKeys.NUM_FILES, files.length)} files\n\t" +
log"${MDC(LogKeys.FILE_NAME, files.mkString("\n\t"))}")
}
- private def newDFSFileName(localFileName: String): String = {
- val baseName = FilenameUtils.getBaseName(localFileName)
- val extension = FilenameUtils.getExtension(localFileName)
- s"$baseName-${UUID.randomUUID}.$extension"
- }
-
def newDFSFileName(localFileName: String, dfsFileSuffix: String): String = {
val baseName = FilenameUtils.getBaseName(localFileName)
val extension = FilenameUtils.getExtension(localFileName)
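The retention change above clamps the number of deletable versions at zero so that a small version range never produces a negative count. A sketch of the arithmetic:

```scala
// Illustrative sketch only, not part of the patch.
import scala.math.max

object VersionRetentionSketch {
  def versionsToDelete(maxSeen: Long, minSeen: Long, numVersionsToRetain: Int): Long =
    max(maxSeen - minSeen + 1 - numVersionsToRetain, 0L)

  def main(args: Array[String]): Unit = {
    assert(versionsToDelete(maxSeen = 10, minSeen = 1, numVersionsToRetain = 20) == 0) // was -10
    assert(versionsToDelete(maxSeen = 10, minSeen = 1, numVersionsToRetain = 3) == 7)
  }
}
```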
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 47721cea4359f..6a36b8c015196 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -32,7 +32,7 @@ import org.apache.spark.internal.LogKeys._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution}
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.execution.streaming.state.StateStoreEncoding.Avro
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.Platform
@@ -67,7 +67,7 @@ private[sql] class RocksDBStateStoreProvider
verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal)
val cfId = rocksDB.createColFamilyIfAbsent(colFamilyName, isInternal)
val dataEncoderCacheKey = StateRowEncoderCacheKey(
- queryRunId = getRunId(hadoopConf),
+ queryRunId = StateStoreProvider.getRunId(hadoopConf),
operatorId = stateStoreId.operatorId,
partitionId = stateStoreId.partitionId,
stateStoreName = stateStoreId.storeName,
@@ -230,10 +230,12 @@ private[sql] class RocksDBStateStoreProvider
}
}
+ var checkpointInfo: Option[StateStoreCheckpointInfo] = None
override def commit(): Long = synchronized {
try {
verify(state == UPDATING, "Cannot commit after already committed or aborted")
- val newVersion = rocksDB.commit()
+ val (newVersion, newCheckpointInfo) = rocksDB.commit()
+ checkpointInfo = Some(newCheckpointInfo)
state = COMMITTED
logInfo(log"Committed ${MDC(VERSION_NUM, newVersion)} " +
log"for ${MDC(STATE_STORE_ID, id)}")
@@ -325,7 +327,8 @@ private[sql] class RocksDBStateStoreProvider
rocksDBMetrics.numUncommittedKeys,
rocksDBMetrics.totalMemUsageBytes,
stateStoreCustomMetrics,
- stateStoreInstanceMetrics)
+ stateStoreInstanceMetrics
+ )
} else {
logInfo(log"Failed to collect metrics for store_id=${MDC(STATE_STORE_ID, id)} " +
log"and version=${MDC(VERSION_NUM, version)}")
@@ -334,8 +337,11 @@ private[sql] class RocksDBStateStoreProvider
}
override def getStateStoreCheckpointInfo(): StateStoreCheckpointInfo = {
- val checkpointInfo = rocksDB.getLatestCheckpointInfo(id.partitionId)
- checkpointInfo
+ checkpointInfo match {
+ case Some(info) => info
+ case None => throw StateStoreErrors.stateStoreOperationOutOfOrder(
+ "Cannot get checkpointInfo without committing the store")
+ }
}
override def hasCommitted: Boolean = state == COMMITTED
@@ -384,6 +390,8 @@ private[sql] class RocksDBStateStoreProvider
this.useColumnFamilies = useColumnFamilies
this.stateStoreEncoding = storeConf.stateStoreEncodingFormat
this.stateSchemaProvider = stateSchemaProvider
+ this.rocksDBEventForwarder =
+ Some(RocksDBEventForwarder(StateStoreProvider.getRunId(hadoopConf), stateStoreId))
if (useMultipleValuesPerKey) {
require(useColumnFamilies, "Multiple values per key support requires column families to be" +
@@ -393,7 +401,7 @@ private[sql] class RocksDBStateStoreProvider
rocksDB // lazy initialization
val dataEncoderCacheKey = StateRowEncoderCacheKey(
- queryRunId = getRunId(hadoopConf),
+ queryRunId = StateStoreProvider.getRunId(hadoopConf),
operatorId = stateStoreId.operatorId,
partitionId = stateStoreId.partitionId,
stateStoreName = stateStoreId.storeName,
@@ -517,6 +525,29 @@ private[sql] class RocksDBStateStoreProvider
@volatile private var useColumnFamilies: Boolean = _
@volatile private var stateStoreEncoding: String = _
@volatile private var stateSchemaProvider: Option[StateSchemaProvider] = _
+ @volatile private var rocksDBEventForwarder: Option[RocksDBEventForwarder] = _
+
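+ // Factory method for the underlying RocksDB instance; kept protected so subclasses can
+ // override how the instance is created.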
+ protected def createRocksDB(
+ dfsRootDir: String,
+ conf: RocksDBConf,
+ localRootDir: File,
+ hadoopConf: Configuration,
+ loggingId: String,
+ useColumnFamilies: Boolean,
+ enableStateStoreCheckpointIds: Boolean,
+ partitionId: Int = 0,
+ eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = {
+ new RocksDB(
+ dfsRootDir,
+ conf,
+ localRootDir,
+ hadoopConf,
+ loggingId,
+ useColumnFamilies,
+ enableStateStoreCheckpointIds,
+ partitionId,
+ eventForwarder)
+ }
private[sql] lazy val rocksDB = {
val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
@@ -524,8 +555,9 @@ private[sql] class RocksDBStateStoreProvider
s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
- new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr,
- useColumnFamilies, storeConf.enableStateStoreCheckpointIds)
+ createRocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr,
+ useColumnFamilies, storeConf.enableStateStoreCheckpointIds, stateStoreId.partitionId,
+ rocksDBEventForwarder)
}
private val keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String,
@@ -796,16 +828,6 @@ object RocksDBStateStoreProvider {
)
}
- private def getRunId(hadoopConf: Configuration): String = {
- val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY)
- if (runId != null) {
- runId
- } else {
- assert(Utils.isTesting, "Failed to find query id/batch Id in task context")
- UUID.randomUUID().toString
- }
- }
-
// Native operation latencies report as latency in microseconds
// as SQLMetrics support millis. Convert the value to millis
val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
@@ -965,3 +987,33 @@ class RocksDBStateStoreChangeDataReader(
}
}
}
+
+/**
+ * Class used to relay events reported from a RocksDB instance to the state store coordinator.
+ *
+ * We pass this into the RocksDB instance to report specific events like snapshot uploads.
+ * This should only be used to report back to the coordinator for metrics and monitoring purposes.
+ */
+private[state] case class RocksDBEventForwarder(queryRunId: String, stateStoreId: StateStoreId) {
+ // Build the state store provider ID from the query run ID and the state store ID
+ private val providerId = StateStoreProviderId(stateStoreId, UUID.fromString(queryRunId))
+
+ /**
+ * Callback function from RocksDB to report events to the coordinator.
+ * Information from the store provider, such as the state store ID and query run ID, is
+ * attached here to report back to the coordinator.
+ *
+ * @param version The snapshot version that was just uploaded from RocksDB
+ */
+ def reportSnapshotUploaded(version: Long): Unit = {
+ // Report the state store provider ID and the version to the coordinator
+ val currentTimestamp = System.currentTimeMillis()
+ StateStoreProvider.coordinatorRef.foreach(
+ _.snapshotUploaded(
+ providerId,
+ version,
+ currentTimestamp
+ )
+ )
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 09acc24aff982..ffaba5ef1502f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -593,7 +593,15 @@ trait StateStoreProvider {
def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] = Seq.empty
}
-object StateStoreProvider {
+object StateStoreProvider extends Logging {
+
+ /**
+ * The state store coordinator reference used to report events such as snapshot uploads from
+ * the state store providers.
+ * For all other messages, refer to the coordinator reference in the [[StateStore]] object.
+ */
+ @GuardedBy("this")
+ private var stateStoreCoordinatorRef: StateStoreCoordinatorRef = _
/**
* Return an instance of the given provider class name. The instance will not be initialized.
@@ -652,6 +660,47 @@ object StateStoreProvider {
}
}
}
+
+ /**
+ * Get the runId from the provided hadoopConf. If it is not found, generate a random UUID.
+ *
+ * @param hadoopConf Hadoop configuration used by the StateStore to save state data
+ */
+ private[state] def getRunId(hadoopConf: Configuration): String = {
+ val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY)
+ if (runId != null) {
+ runId
+ } else {
+ assert(Utils.isTesting, "Failed to find query id/batch Id in task context")
+ UUID.randomUUID().toString
+ }
+ }
+
+ /**
+ * Get or create the state store coordinator reference, which is reused across state store
+ * providers in the executor.
+ * This reference should only be used to report events from store providers regarding snapshot
+ * uploads, to avoid lock contention with other coordinator RPC messages.
+ */
+ private[state] def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized {
+ val env = SparkEnv.get
+ if (env != null) {
+ val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER
+ // If running locally, then the coordinator reference in stateStoreCoordinatorRef may have
+ // become inactive as SparkContext + SparkEnv may have been restarted. Hence, when running on
+ // the driver, always recreate the reference.
+ if (isDriver || stateStoreCoordinatorRef == null) {
+ logDebug("Getting StateStoreCoordinatorRef")
+ stateStoreCoordinatorRef = StateStoreCoordinatorRef.forExecutor(env)
+ }
+ logInfo(log"Retrieved reference to StateStoreCoordinator: " +
+ log"${MDC(LogKeys.STATE_STORE_COORDINATOR, stateStoreCoordinatorRef)}")
+ Some(stateStoreCoordinatorRef)
+ } else {
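+ // Without an active SparkEnv there is no coordinator to reach, so drop any cached reference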
+ stateStoreCoordinatorRef = null
+ None
+ }
+ }
}
/**
@@ -838,7 +887,9 @@ object StateStore extends Logging {
* Thread Pool that runs maintenance on partitions that are scheduled by
* MaintenanceTask periodically
*/
- class MaintenanceThreadPool(numThreads: Int) {
+ class MaintenanceThreadPool(
+ numThreads: Int,
+ shutdownTimeout: Long) {
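+ // shutdownTimeout is the time in seconds to wait for running maintenance tasks to finish
+ // before forcing the thread pool to shut down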
private val threadPool = ThreadUtils.newDaemonFixedThreadPool(
numThreads, "state-store-maintenance-thread")
@@ -851,10 +902,11 @@ object StateStore extends Logging {
threadPool.shutdown() // Disable new tasks from being submitted
// Wait a while for existing tasks to terminate
- if (!threadPool.awaitTermination(5 * 60, TimeUnit.SECONDS)) {
+ if (!threadPool.awaitTermination(shutdownTimeout, TimeUnit.SECONDS)) {
logWarning(
- s"MaintenanceThreadPool is not able to be terminated within 300 seconds," +
- " forcefully shutting down now.")
+ log"MaintenanceThreadPool failed to terminate within " +
+ log"waitTimeout=${MDC(LogKeys.TIMEOUT, shutdownTimeout)} seconds, " +
+ log"forcefully shutting down now.")
threadPool.shutdownNow() // Cancel currently executing tasks
// Wait a while for tasks to respond to being cancelled
@@ -955,13 +1007,29 @@ object StateStore extends Logging {
log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, storeProviderId.queryRunId)}")
}
- val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
- val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
- providerIdsToUnload.foreach(unload(_))
+ // Only tell the state store coordinator we are active if we will remain active
+ // after the task. When we unload after committing, there's no need for the coordinator
+ // to track which executor has which provider.
+ if (!storeConf.unloadOnCommit) {
+ val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
+ val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
+ providerIdsToUnload.foreach(unload(_))
+ }
+
provider
}
}
+ /** Run maintenance and then unload a state store provider */
+ def doMaintenanceAndUnload(storeProviderId: StateStoreProviderId): Unit = {
+ loadedProviders.synchronized {
+ loadedProviders.remove(storeProviderId)
+ }.foreach { provider =>
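+ // Run one final maintenance pass before closing the provider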
+ provider.doMaintenance()
+ provider.close()
+ }
+ }
+
/** Unload a state store provider */
def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized {
loadedProviders.remove(storeProviderId).foreach(_.close())
@@ -985,6 +1053,14 @@ object StateStore extends Logging {
/** Stop maintenance thread and reset the maintenance task */
def stopMaintenanceTask(): Unit = loadedProviders.synchronized {
+ stopMaintenanceTaskWithoutLock()
+ }
+
+ /**
+ * Only used for unit tests. This function does not hold the loadedProviders lock, so calling
+ * it can work around a deadlock condition where a maintenance task is waiting for the lock.
+ */
+ private[streaming] def stopMaintenanceTaskWithoutLock(): Unit = {
if (maintenanceThreadPool != null) {
maintenanceThreadPoolLock.synchronized {
maintenancePartitions.clear()
@@ -1010,13 +1086,15 @@ object StateStore extends Logging {
/** Start the periodic maintenance task if not already started and if Spark active */
private def startMaintenanceIfNeeded(storeConf: StateStoreConf): Unit = {
val numMaintenanceThreads = storeConf.numStateStoreMaintenanceThreads
+ val maintenanceShutdownTimeout = storeConf.stateStoreMaintenanceShutdownTimeout
loadedProviders.synchronized {
- if (SparkEnv.get != null && !isMaintenanceRunning) {
+ if (SparkEnv.get != null && !isMaintenanceRunning && !storeConf.unloadOnCommit) {
maintenanceTask = new MaintenanceTask(
storeConf.maintenanceInterval,
task = { doMaintenance() }
)
- maintenanceThreadPool = new MaintenanceThreadPool(numMaintenanceThreads)
+ maintenanceThreadPool = new MaintenanceThreadPool(numMaintenanceThreads,
+ maintenanceShutdownTimeout)
logInfo("State Store maintenance task started")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
index b4fbb5560f2f4..bcaff4c60d08f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
@@ -368,7 +368,10 @@ class StateStoreChangelogReaderFactory(
// When there is no record being written in the changelog file in V1,
// the file contains a single int -1 meaning EOF, then the above readUTF()
// throws with EOFException and we return version 1.
- case _: java.io.EOFException => 1
+ // Or, if the first record in the V1 changelog file has a large enough key,
+ // readUTF() will throw a UTFDataFormatException, so we should also return
+ // version 1 (SPARK-51922).
+ case _: java.io.EOFException | _: java.io.UTFDataFormatException => 1
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index 9d26bf8fdf2e7..9a994200baeb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -33,6 +33,11 @@ class StateStoreConf(
*/
val numStateStoreMaintenanceThreads: Int = sqlConf.numStateStoreMaintenanceThreads
+ /**
+ * Timeout, in seconds, for state store maintenance operations to complete on shutdown
+ */
+ val stateStoreMaintenanceShutdownTimeout: Long = sqlConf.stateStoreMaintenanceShutdownTimeout
+
/**
* Minimum number of delta files in a chain after which HDFSBackedStateStore will
* consider generating a snapshot.
@@ -92,6 +97,15 @@ class StateStoreConf(
val enableStateStoreCheckpointIds =
StatefulOperatorStateInfo.enableStateStoreCheckpointIds(sqlConf)
+ /**
+ * Whether the coordinator reports state stores that are lagging behind in snapshot uploads.
+ */
+ val reportSnapshotUploadLag: Boolean =
+ sqlConf.stateStoreCoordinatorReportSnapshotUploadLag
+
+ /** Whether to unload the store on task completion. */
+ val unloadOnCommit: Boolean = sqlConf.stateStoreUnloadOnCommit
+
/**
* Additional configurations related to state store. This will capture all configs in
* SQLConf that start with `spark.sql.streaming.stateStore.`
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
index 84b77efea3caf..903f27fb2a223 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -22,9 +22,10 @@ import java.util.UUID
import scala.collection.mutable
import org.apache.spark.SparkEnv
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.RpcUtils
/** Trait representing all messages to [[StateStoreCoordinator]] */
@@ -55,6 +56,45 @@ private case class GetLocation(storeId: StateStoreProviderId)
private case class DeactivateInstances(runId: UUID)
extends StateStoreCoordinatorMessage
+/**
+ * This message is used to report that a state store has just finished uploading a snapshot,
+ * along with the snapshot version and the upload timestamp in milliseconds.
+ */
+private case class ReportSnapshotUploaded(
+ providerId: StateStoreProviderId,
+ version: Long,
+ timestamp: Long)
+ extends StateStoreCoordinatorMessage
+
+/**
+ * This message is used for the coordinator to look for all state stores that are lagging behind
+ * in snapshot uploads. The coordinator will then log a warning message for each lagging instance.
+ */
+private case class LogLaggingStateStores(
+ queryRunId: UUID,
+ latestVersion: Long,
+ isTerminatingTrigger: Boolean)
+ extends StateStoreCoordinatorMessage
+
+/**
+ * Message used for testing.
+ * This message is used to retrieve the latest snapshot version reported for upload from a
+ * specific state store.
+ */
+private case class GetLatestSnapshotVersionForTesting(providerId: StateStoreProviderId)
+ extends StateStoreCoordinatorMessage
+
+/**
+ * Message used for testing.
+ * This message is used to retrieve all active state store instances falling behind in
+ * snapshot uploads, using version and time criteria.
+ */
+private case class GetLaggingStoresForTesting(
+ queryRunId: UUID,
+ latestVersion: Long,
+ isTerminatingTrigger: Boolean)
+ extends StateStoreCoordinatorMessage
+
private object StopCoordinator
extends StateStoreCoordinatorMessage
@@ -66,9 +106,9 @@ object StateStoreCoordinatorRef extends Logging {
/**
* Create a reference to a [[StateStoreCoordinator]]
*/
- def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
+ def forDriver(env: SparkEnv, sqlConf: SQLConf): StateStoreCoordinatorRef = synchronized {
try {
- val coordinator = new StateStoreCoordinator(env.rpcEnv)
+ val coordinator = new StateStoreCoordinator(env.rpcEnv, sqlConf)
val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator)
logInfo("Registered StateStoreCoordinator endpoint")
new StateStoreCoordinatorRef(coordinatorRef)
@@ -119,6 +159,46 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId))
}
+ /** Inform that an executor has uploaded a snapshot */
+ private[sql] def snapshotUploaded(
+ providerId: StateStoreProviderId,
+ version: Long,
+ timestamp: Long): Boolean = {
+ rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(providerId, version, timestamp))
+ }
+
+ /** Ask the coordinator to log all state store instances that are lagging behind in uploads */
+ private[sql] def logLaggingStateStores(
+ queryRunId: UUID,
+ latestVersion: Long,
+ isTerminatingTrigger: Boolean): Boolean = {
+ rpcEndpointRef.askSync[Boolean](
+ LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger))
+ }
+
+ /**
+ * Endpoint used for testing.
+ * Get the latest snapshot version uploaded for a state store.
+ */
+ private[state] def getLatestSnapshotVersionForTesting(
+ providerId: StateStoreProviderId): Option[Long] = {
+ rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersionForTesting(providerId))
+ }
+
+ /**
+ * Endpoint used for testing.
+ * Get the state store instances that are falling behind in snapshot uploads for a particular
+ * query run.
+ */
+ private[state] def getLaggingStoresForTesting(
+ queryRunId: UUID,
+ latestVersion: Long,
+ isTerminatingTrigger: Boolean = false): Seq[StateStoreProviderId] = {
+ rpcEndpointRef.askSync[Seq[StateStoreProviderId]](
+ GetLaggingStoresForTesting(queryRunId, latestVersion, isTerminatingTrigger)
+ )
+ }
+
private[state] def stop(): Unit = {
rpcEndpointRef.askSync[Boolean](StopCoordinator)
}
@@ -129,10 +209,30 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
* Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster,
* and get their locations for job scheduling.
*/
-private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
- extends ThreadSafeRpcEndpoint with Logging {
+private class StateStoreCoordinator(
+ override val rpcEnv: RpcEnv,
+ val sqlConf: SQLConf)
+ extends ThreadSafeRpcEndpoint with Logging {
private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation]
+ // Stores the latest snapshot upload event for a specific state store
+ private val stateStoreLatestUploadedSnapshot =
+ new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent]
+
+ // Default snapshot upload event to use when a provider has never uploaded a snapshot
+ private val defaultSnapshotUploadEvent = SnapshotUploadEvent(0, 0)
+
+ // Stores, for each queryRunId, the timestamp in milliseconds of the coordinator's last full
+ // report on instances lagging behind on snapshot uploads.
+ // The timestamp defaults to 0 milliseconds.
+ private val lastFullSnapshotLagReportTimeMs = new mutable.HashMap[UUID, Long]
+
+ private def shouldCoordinatorReportSnapshotLag: Boolean =
+ sqlConf.stateStoreCoordinatorReportSnapshotUploadLag
+
+ private def coordinatorLagReportInterval: Long =
+ sqlConf.stateStoreCoordinatorSnapshotLagReportInterval
+
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case ReportActiveInstance(id, host, executorId, providerIdsToCheck) =>
logDebug(s"Reported state store $id is active at $executorId")
@@ -164,13 +264,160 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
val storeIdsToRemove =
instances.keys.filter(_.queryRunId == runId).toSeq
instances --= storeIdsToRemove
+ // Also remove these instances from snapshot upload event tracking
+ stateStoreLatestUploadedSnapshot --= storeIdsToRemove
+ // Remove the corresponding run id entries for report time and starting time
+ lastFullSnapshotLagReportTimeMs -= runId
logDebug(s"Deactivating instances related to checkpoint location $runId: " +
storeIdsToRemove.mkString(", "))
context.reply(true)
+ case ReportSnapshotUploaded(providerId, version, timestamp) =>
+ // Ignore this upload event if the latest registered version for the store is more recent,
+ // since an older snapshot can be uploaded after another executor has already uploaded a
+ // newer snapshot for the same state store.
+ logDebug(s"Snapshot version $version was uploaded for state store $providerId")
+ if (!stateStoreLatestUploadedSnapshot.get(providerId).exists(_.version >= version)) {
+ stateStoreLatestUploadedSnapshot.put(providerId, SnapshotUploadEvent(version, timestamp))
+ }
+ context.reply(true)
+
+ case LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger) =>
+ val currentTimestamp = System.currentTimeMillis()
+ // Only log lagging instances if snapshot lag reporting and uploading are enabled,
+ // otherwise all instances will be considered lagging.
+ if (shouldCoordinatorReportSnapshotLag) {
+ val laggingStores =
+ findLaggingStores(queryRunId, latestVersion, currentTimestamp, isTerminatingTrigger)
+ if (laggingStores.nonEmpty) {
+ logWarning(
+ log"StateStoreCoordinator Snapshot Lag Report for " +
+ log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+ log"Number of state stores falling behind: " +
+ log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}"
+ )
+ // Report all stores that are behind in snapshot uploads.
+ // Only report the list of providers lagging behind if the last reported time
+ // is not recent for this query run. The lag report interval denotes the minimum
+ // time between these full reports.
+ val timeSinceLastReport =
+ currentTimestamp - lastFullSnapshotLagReportTimeMs.getOrElse(queryRunId, 0L)
+ if (timeSinceLastReport > coordinatorLagReportInterval) {
+ // Mark timestamp of the report and log the lagging instances
+ lastFullSnapshotLagReportTimeMs.put(queryRunId, currentTimestamp)
+ // Only report the stores that are lagging the most behind in snapshot uploads.
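+ // SnapshotUploadEvent orders by version first and timestamp second, so the most
+ // lagging stores sort first.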
+ laggingStores
+ .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, defaultSnapshotUploadEvent))
+ .take(sqlConf.stateStoreCoordinatorMaxLaggingStoresToReport)
+ .foreach { providerId =>
+ val baseLogMessage =
+ log"StateStoreCoordinator Snapshot Lag Detected for " +
+ log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+ log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, providerId.storeId)} " +
+ log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}"
+
+ val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match {
+ case Some(snapshotEvent) =>
+ val versionDelta = latestVersion - snapshotEvent.version
+ val timeDelta = currentTimestamp - snapshotEvent.timestamp
+
+ baseLogMessage + log", " +
+ log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " +
+ log"version delta: " +
+ log"${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " +
+ log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)"
+ case None =>
+ baseLogMessage + log", latest snapshot: no upload for query run)"
+ }
+ logWarning(logMessage)
+ }
+ }
+ }
+ }
+ context.reply(true)
+
+ case GetLatestSnapshotVersionForTesting(providerId) =>
+ val version = stateStoreLatestUploadedSnapshot.get(providerId).map(_.version)
+ logDebug(s"Got latest snapshot version of the state store $providerId: $version")
+ context.reply(version)
+
+ case GetLaggingStoresForTesting(queryRunId, latestVersion, isTerminatingTrigger) =>
+ val currentTimestamp = System.currentTimeMillis()
+ // Only report if snapshot lag reporting is enabled
+ if (shouldCoordinatorReportSnapshotLag) {
+ val laggingStores =
+ findLaggingStores(queryRunId, latestVersion, currentTimestamp, isTerminatingTrigger)
+ logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}")
+ context.reply(laggingStores)
+ } else {
+ context.reply(Seq.empty)
+ }
+
case StopCoordinator =>
stop() // Stop before replying to ensure that endpoint name has been deregistered
logInfo("StateStoreCoordinator stopped")
context.reply(true)
}
+
+ private def findLaggingStores(
+ queryRunId: UUID,
+ referenceVersion: Long,
+ referenceTimestamp: Long,
+ isTerminatingTrigger: Boolean): Seq[StateStoreProviderId] = {
+ // Determine alert thresholds from configurations for both time and version differences.
+ val snapshotVersionDeltaMultiplier =
+ sqlConf.stateStoreCoordinatorMultiplierForMinVersionDiffToLog
+ val maintenanceIntervalMultiplier = sqlConf.stateStoreCoordinatorMultiplierForMinTimeDiffToLog
+ val minDeltasForSnapshot = sqlConf.stateStoreMinDeltasForSnapshot
+ val maintenanceInterval = sqlConf.streamingMaintenanceInterval
+
+ // Use the configured multipliers multiplierForMinVersionDiffToLog and
+ // multiplierForMinTimeDiffToLog to determine the proper alert thresholds.
+ val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot
+ val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval
+
+ // Look for active state store providers that are lagging behind in snapshot uploads.
+ // The coordinator should only consider providers that are part of this specific query run.
+ instances.view.keys
+ .filter(_.queryRunId == queryRunId)
+ .filter { storeProviderId =>
+ // Stores that have not uploaded a snapshot are treated as having a snapshot of
+ // version 0 and timestamp 0 ms.
+ val latestSnapshot = stateStoreLatestUploadedSnapshot.getOrElse(
+ storeProviderId,
+ defaultSnapshotUploadEvent
+ )
+ // A state store is considered lagging if it's behind in both version and time according
+ // to the configured thresholds.
+ val isBehindOnVersions =
+ referenceVersion - latestSnapshot.version > minVersionDeltaForLogging
+ val isBehindOnTime =
+ referenceTimestamp - latestSnapshot.timestamp > minTimeDeltaForLogging
+ // If the query is using a trigger that self-terminates like OneTimeTrigger
+ // and AvailableNowTrigger, we ignore the time threshold check as the upload frequency
+ // is not fully dependent on the maintenance interval.
+ isBehindOnVersions && (isTerminatingTrigger || isBehindOnTime)
+ }.toSeq
+ }
+}
+
+case class SnapshotUploadEvent(
+ version: Long,
+ timestamp: Long
+) extends Ordered[SnapshotUploadEvent] {
+
+ override def compare(otherEvent: SnapshotUploadEvent): Int = {
+ // Compare by version first, then by timestamp as tiebreaker
+ val versionCompare = this.version.compare(otherEvent.version)
+ if (versionCompare == 0) {
+ this.timestamp.compare(otherEvent.timestamp)
+ } else {
+ versionCompare
+ }
+ }
+
+ override def toString(): String = {
+ s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)"
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index 188306e82f688..b3bfce752fcf6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -212,6 +212,15 @@ object StateStoreErrors {
StateStoreInvalidVariableTypeChange = {
new StateStoreInvalidVariableTypeChange(stateName, oldType, newType)
}
+
+ def failedToGetChangelogWriter(version: Long, e: Throwable):
+ StateStoreFailedToGetChangelogWriter = {
+ new StateStoreFailedToGetChangelogWriter(version, e)
+ }
+
+ def stateStoreOperationOutOfOrder(errorMsg: String): StateStoreOperationOutOfOrder = {
+ new StateStoreOperationOutOfOrder(errorMsg)
+ }
}
class StateStoreDuplicateStateVariableDefined(stateVarName: String)
@@ -410,6 +419,12 @@ class StateStoreSnapshotPartitionNotFound(
"operatorId" -> operatorId.toString,
"checkpointLocation" -> checkpointLocation))
+class StateStoreFailedToGetChangelogWriter(version: Long, e: Throwable)
+ extends SparkRuntimeException(
+ errorClass = "CANNOT_LOAD_STATE_STORE.FAILED_TO_GET_CHANGELOG_WRITER",
+ messageParameters = Map("version" -> version.toString),
+ cause = e)
+
class StateStoreKeyRowFormatValidationFailure(errorMsg: String)
extends SparkRuntimeException(
errorClass = "STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE",
@@ -424,3 +439,9 @@ class StateStoreProviderDoesNotSupportFineGrainedReplay(inputClass: String)
extends SparkUnsupportedOperationException(
errorClass = "STATE_STORE_PROVIDER_DOES_NOT_SUPPORT_FINE_GRAINED_STATE_REPLAY",
messageParameters = Map("inputClass" -> inputClass))
+
+class StateStoreOperationOutOfOrder(errorMsg: String)
+ extends SparkRuntimeException(
+ errorClass = "STATE_STORE_OPERATION_OUT_OF_ORDER",
+ messageParameters = Map("errorMsg" -> errorMsg)
+ )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index d51db6e606e13..70b4932af6017 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -136,6 +136,13 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
stateSchemaBroadcast,
useColumnFamilies, storeConf, hadoopConfBroadcast.value.value,
useMultipleValuesPerKey)
+
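+ // With unloadOnCommit enabled, run maintenance and unload the provider as soon as
+ // the task completes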
+ if (storeConf.unloadOnCommit) {
+ ctxt.addTaskCompletionListener[Unit](_ => {
+ StateStore.doMaintenanceAndUnload(storeProviderId)
+ })
+ }
+
storeUpdateFunction(store, inputIter)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index 66ab0006c4982..6ec197d7cc7b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -58,6 +58,9 @@ import org.apache.spark.util.NextIterator
* store providers being used in this class. If true, Spark will
* take care of management for state store providers, e.g. running
* maintenance task for these providers.
+ * @param joinStoreGenerator The generator used to create state store instances, reusing the
+ * same instance when the join implementation uses virtual column
+ * families (state format version 3).
*
* Internally, the key -> multiple values is stored in two [[StateStore]]s.
* - Store 1 ([[KeyToNumValuesStore]]) maintains mapping between key -> number of values
@@ -78,8 +81,8 @@ import org.apache.spark.util.NextIterator
* by overwriting with the value of (key, maxIndex), and removing [(key, maxIndex),
* decrement corresponding num values in KeyToNumValuesStore
*/
-class SymmetricHashJoinStateManager(
- val joinSide: JoinSide,
+abstract class SymmetricHashJoinStateManager(
+ joinSide: JoinSide,
inputValueAttributes: Seq[Attribute],
joinKeys: Seq[Expression],
stateInfo: Option[StatefulOperatorStateInfo],
@@ -91,9 +94,16 @@ class SymmetricHashJoinStateManager(
stateFormatVersion: Int,
skippedNullValueCount: Option[SQLMetric] = None,
useStateStoreCoordinator: Boolean = true,
- snapshotStartVersion: Option[Long] = None) extends Logging {
+ snapshotStartVersion: Option[Long] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator) extends Logging {
import SymmetricHashJoinStateManager._
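+ // Key schema/attributes and handles for the two logical stores; with state format version 3
+ // (virtual column families) both handles are backed by the same physical store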
+ protected val keySchema = StructType(
+ joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
+ protected val keyAttributes = toAttributes(keySchema)
+ protected val keyToNumValues = new KeyToNumValuesStore(stateFormatVersion)
+ protected val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion)
+
/*
=====================================================
Public methods
@@ -403,55 +413,26 @@ class SymmetricHashJoinStateManager(
def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow = keyProjection(currentKey)
/** Commit all the changes to all the state stores */
- def commit(): Unit = {
- keyToNumValues.commit()
- keyWithIndexToValue.commit()
- }
+ def commit(): Unit
/** Abort any changes to the state stores if needed */
- def abortIfNeeded(): Unit = {
- keyToNumValues.abortIfNeeded()
- keyWithIndexToValue.abortIfNeeded()
- }
+ def abortIfNeeded(): Unit
/**
* Get state store checkpoint information of the two state stores for this joiner, after
* they finished data processing.
+ *
+ * For [[SymmetricHashJoinStateManagerV1]], this returns the information of the two stores
+ * used for this joiner.
+ *
+ * For [[SymmetricHashJoinStateManagerV2]], this returns the information of the single store
+ * used for the entire joiner operator. Both fields of JoinerStateStoreCkptInfo will
+ * be identical.
*/
- def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo = {
- val keyToNumValuesCkptInfo = keyToNumValues.getLatestCheckpointInfo()
- val keyWithIndexToValueCkptInfo = keyWithIndexToValue.getLatestCheckpointInfo()
-
- assert(
- keyToNumValuesCkptInfo.partitionId == keyWithIndexToValueCkptInfo.partitionId,
- "two state stores in a stream-stream joiner don't return the same partition ID")
- assert(
- keyToNumValuesCkptInfo.batchVersion == keyWithIndexToValueCkptInfo.batchVersion,
- "two state stores in a stream-stream joiner don't return the same batch version")
- assert(
- keyToNumValuesCkptInfo.stateStoreCkptId.isDefined ==
- keyWithIndexToValueCkptInfo.stateStoreCkptId.isDefined,
- "two state stores in a stream-stream joiner should both return checkpoint ID or not")
-
- JoinerStateStoreCkptInfo(keyToNumValuesCkptInfo, keyWithIndexToValueCkptInfo)
- }
+ def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
/** Get the combined metrics of all the state stores */
- def metrics: StateStoreMetrics = {
- val keyToNumValuesMetrics = keyToNumValues.metrics
- val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics
- def newDesc(desc: String): String = s"${joinSide.toString.toUpperCase(Locale.ROOT)}: $desc"
-
- StateStoreMetrics(
- keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once
- keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes,
- keyWithIndexToValueMetrics.customMetrics.map {
- case (metric, value) => (metric.withNewDesc(desc = newDesc(metric.desc)), value)
- },
- // We want to collect instance metrics from both state stores
- keyWithIndexToValueMetrics.instanceMetrics ++ keyToNumValuesMetrics.instanceMetrics
- )
- }
+ def metrics: StateStoreMetrics
/**
* Update number of values for a key.
@@ -468,17 +449,11 @@ class SymmetricHashJoinStateManager(
=====================================================
*/
- private val keySchema = StructType(
- joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
- private val keyAttributes = toAttributes(keySchema)
- private val keyToNumValues = new KeyToNumValuesStore()
- private val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion)
-
// Clean up any state store resources if necessary at the end of the task
Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } }
/** Helper trait for invoking common functionalities of a state store. */
- private abstract class StateStoreHandler(
+ protected abstract class StateStoreHandler(
stateStoreType: StateStoreType,
stateStoreCkptId: Option[String]) extends Logging {
private var stateStoreProvider: StateStoreProvider = _
@@ -510,21 +485,28 @@ class SymmetricHashJoinStateManager(
}
/** Get the StateStore with the given schema */
- protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = {
- val storeProviderId = StateStoreProviderId(
- stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType))
+ protected def getStateStore(
+ keySchema: StructType,
+ valueSchema: StructType,
+ useVirtualColumnFamilies: Boolean): StateStore = {
+ val storeName = if (useVirtualColumnFamilies) {
+ StateStoreId.DEFAULT_STORE_NAME
+ } else {
+ getStateStoreName(joinSide, stateStoreType)
+ }
+ val storeProviderId = StateStoreProviderId(stateInfo.get, partitionId, storeName)
val store = if (useStateStoreCoordinator) {
assert(snapshotStartVersion.isEmpty, "Should not use state store coordinator " +
"when reading state as data source.")
- StateStore.get(
+ joinStoreGenerator.getStore(
storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
- stateInfo.get.storeVersion, stateStoreCkptId, None, useColumnFamilies = false,
+ stateInfo.get.storeVersion, stateStoreCkptId, None, useVirtualColumnFamilies,
storeConf, hadoopConf)
} else {
// This class will manage the state store provider by itself.
stateStoreProvider = StateStoreProvider.createAndInit(
storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
- useColumnFamilies = false, storeConf, hadoopConf,
+ useColumnFamilies = useVirtualColumnFamilies, storeConf, hadoopConf,
useMultipleValuesPerKey = false, stateSchemaProvider = None)
if (snapshotStartVersion.isDefined) {
if (!stateStoreProvider.isInstanceOf[SupportsFineGrainedReplay]) {
@@ -546,7 +528,7 @@ class SymmetricHashJoinStateManager(
* Helper class for representing data returned by [[KeyWithIndexToValueStore]].
* Designed for object reuse.
*/
- private class KeyAndNumValues(var key: UnsafeRow = null, var numValue: Long = 0) {
+ private[state] class KeyAndNumValues(var key: UnsafeRow = null, var numValue: Long = 0) {
def withNew(newKey: UnsafeRow, newNumValues: Long): this.type = {
this.key = newKey
this.numValue = newNumValues
@@ -556,16 +538,37 @@ class SymmetricHashJoinStateManager(
/** A wrapper around a [[StateStore]] that stores [key -> number of values]. */
- private class KeyToNumValuesStore
+ protected class KeyToNumValuesStore(val stateFormatVersion: Int)
extends StateStoreHandler(KeyToNumValuesType, keyToNumValuesStateStoreCkptId) {
+
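+ // State format version 3 keeps this store as a virtual column family inside the single
+ // shared join state store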
+ private val useVirtualColumnFamilies = stateFormatVersion == 3
private val longValueSchema = new StructType().add("value", "long")
private val longToUnsafeRow = UnsafeProjection.create(longValueSchema)
private val valueRow = longToUnsafeRow(new SpecificInternalRow(longValueSchema))
- protected val stateStore: StateStore = getStateStore(keySchema, longValueSchema)
+ protected val stateStore: StateStore =
+ getStateStore(keySchema, longValueSchema, useVirtualColumnFamilies)
+
+ // Set up virtual column family name in the store if it is being used
+ private val colFamilyName = if (useVirtualColumnFamilies) {
+ getStateStoreName(joinSide, KeyToNumValuesType)
+ } else {
+ StateStore.DEFAULT_COL_FAMILY_NAME
+ }
+
+ // Create the specific column family in the store for this join side's KeyToNumValuesStore
+ if (useVirtualColumnFamilies) {
+ stateStore.createColFamilyIfAbsent(
+ colFamilyName,
+ keySchema,
+ longValueSchema,
+ NoPrefixKeyStateEncoderSpec(keySchema),
+ isInternal = true
+ )
+ }
/** Get the number of values the key has */
def get(key: UnsafeRow): Long = {
- val longValueRow = stateStore.get(key)
+ val longValueRow = stateStore.get(key, colFamilyName)
if (longValueRow != null) longValueRow.getLong(0) else 0L
}
@@ -573,16 +576,16 @@ class SymmetricHashJoinStateManager(
def put(key: UnsafeRow, numValues: Long): Unit = {
require(numValues > 0)
valueRow.setLong(0, numValues)
- stateStore.put(key, valueRow)
+ stateStore.put(key, valueRow, colFamilyName)
}
def remove(key: UnsafeRow): Unit = {
- stateStore.remove(key)
+ stateStore.remove(key, colFamilyName)
}
def iterator: Iterator[KeyAndNumValues] = {
val keyAndNumValues = new KeyAndNumValues()
- stateStore.iterator().map { pair =>
+ stateStore.iterator(colFamilyName).map { pair =>
keyAndNumValues.withNew(pair.key, pair.value.getLong(0))
}
}
@@ -592,7 +595,7 @@ class SymmetricHashJoinStateManager(
* Helper class for representing data returned by [[KeyWithIndexToValueStore]].
* Designed for object reuse.
*/
- private class KeyWithIndexAndValue(
+ private[state] class KeyWithIndexAndValue(
var key: UnsafeRow = null,
var valueIndex: Long = -1,
var value: UnsafeRow = null,
@@ -653,7 +656,7 @@ class SymmetricHashJoinStateManager(
private object KeyWithIndexToValueRowConverter {
def create(version: Int): KeyWithIndexToValueRowConverter = version match {
case 1 => new KeyWithIndexToValueRowConverterFormatV1()
- case 2 => new KeyWithIndexToValueRowConverterFormatV2()
+ case 2 | 3 => new KeyWithIndexToValueRowConverterFormatV2()
case _ => throw new IllegalArgumentException("Incorrect state format version! " +
s"version $version")
}
@@ -703,9 +706,10 @@ class SymmetricHashJoinStateManager(
* A wrapper around a [[StateStore]] that stores the mapping; the mapping depends on the
* state format version - please refer implementations of [[KeyWithIndexToValueRowConverter]].
*/
- private class KeyWithIndexToValueStore(stateFormatVersion: Int)
+ protected class KeyWithIndexToValueStore(stateFormatVersion: Int)
extends StateStoreHandler(KeyWithIndexToValueType, keyWithIndexToValueStateStoreCkptId) {
+ private val useVirtualColumnFamilies = stateFormatVersion == 3
private val keyWithIndexExprs = keyAttributes :+ Literal(1L)
private val keyWithIndexSchema = keySchema.add("index", LongType)
private val indexOrdinalInKeyWithIndexRow = keyAttributes.size
@@ -720,10 +724,29 @@ class SymmetricHashJoinStateManager(
private val valueRowConverter = KeyWithIndexToValueRowConverter.create(stateFormatVersion)
protected val stateStore = getStateStore(keyWithIndexSchema,
- valueRowConverter.valueAttributes.toStructType)
+ valueRowConverter.valueAttributes.toStructType, useVirtualColumnFamilies)
+
+ // Set up virtual column family name in the store if it is being used
+ private val colFamilyName = if (useVirtualColumnFamilies) {
+ getStateStoreName(joinSide, KeyWithIndexToValueType)
+ } else {
+ StateStore.DEFAULT_COL_FAMILY_NAME
+ }
+
+ // Create the specific column family in the store for this join side's KeyWithIndexToValueStore
+ if (useVirtualColumnFamilies) {
+ stateStore.createColFamilyIfAbsent(
+ colFamilyName,
+ keySchema,
+ valueRowConverter.valueAttributes.toStructType,
+ NoPrefixKeyStateEncoderSpec(keySchema)
+ )
+ }
def get(key: UnsafeRow, valueIndex: Long): ValueAndMatchPair = {
- valueRowConverter.convertValue(stateStore.get(keyWithIndexRow(key, valueIndex)))
+ valueRowConverter.convertValue(
+ stateStore.get(keyWithIndexRow(key, valueIndex), colFamilyName)
+ )
}
/**
@@ -741,7 +764,8 @@ class SymmetricHashJoinStateManager(
override protected def getNext(): KeyWithIndexAndValue = {
while (hasMoreValues) {
val keyWithIndex = keyWithIndexRow(key, index)
- val valuePair = valueRowConverter.convertValue(stateStore.get(keyWithIndex))
+ val valuePair =
+ valueRowConverter.convertValue(stateStore.get(keyWithIndex, colFamilyName))
if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) {
skippedNullValueCount.foreach(_ += 1L)
index += 1
@@ -764,7 +788,7 @@ class SymmetricHashJoinStateManager(
def put(key: UnsafeRow, valueIndex: Long, value: UnsafeRow, matched: Boolean): Unit = {
val keyWithIndex = keyWithIndexRow(key, valueIndex)
val valueWithMatched = valueRowConverter.convertToValueRow(value, matched)
- stateStore.put(keyWithIndex, valueWithMatched)
+ stateStore.put(keyWithIndex, valueWithMatched, colFamilyName)
}
/**
@@ -772,21 +796,21 @@ class SymmetricHashJoinStateManager(
* (key, index) and it is upto the caller to deal with it.
*/
def remove(key: UnsafeRow, valueIndex: Long): Unit = {
- stateStore.remove(keyWithIndexRow(key, valueIndex))
+ stateStore.remove(keyWithIndexRow(key, valueIndex), colFamilyName)
}
/** Remove all values (i.e. all the indices) for the given key. */
def removeAllValues(key: UnsafeRow, numValues: Long): Unit = {
var index = 0
while (index < numValues) {
- stateStore.remove(keyWithIndexRow(key, index))
+ stateStore.remove(keyWithIndexRow(key, index), colFamilyName)
index += 1
}
}
def iterator: Iterator[KeyWithIndexAndValue] = {
val keyWithIndexAndValue = new KeyWithIndexAndValue()
- stateStore.iterator().map { pair =>
+ stateStore.iterator(colFamilyName).map { pair =>
val valuePair = valueRowConverter.convertValue(pair.value)
keyWithIndexAndValue.withNew(
keyRowGenerator(pair.key), pair.key.getLong(indexOrdinalInKeyWithIndexRow), valuePair)
@@ -803,10 +827,232 @@ class SymmetricHashJoinStateManager(
}
}
+/**
+ * Streaming join state manager that uses 4 state stores without virtual column families.
+ * This implementation creates a state stores based on the join side and the type of state store.
+ *
+ * The keyToNumValues store tracks the number of rows for each key, and the keyWithIndexToValue
+ * store contains the actual entries with an additional index column.
+ */
+class SymmetricHashJoinStateManagerV1(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ keyToNumValuesStateStoreCkptId: Option[String],
+ keyWithIndexToValueStateStoreCkptId: Option[String],
+ stateFormatVersion: Int,
+ skippedNullValueCount: Option[SQLMetric] = None,
+ useStateStoreCoordinator: Boolean = true,
+ snapshotStartVersion: Option[Long] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator) extends SymmetricHashJoinStateManager(
+ joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
+ partitionId, keyToNumValuesStateStoreCkptId, keyWithIndexToValueStateStoreCkptId,
+ stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotStartVersion,
+ joinStoreGenerator) {
+
+ /** Commit all the changes to all the state stores */
+ override def commit(): Unit = {
+ keyToNumValues.commit()
+ keyWithIndexToValue.commit()
+ }
+
+ /** Abort any changes to the state stores if needed */
+ override def abortIfNeeded(): Unit = {
+ keyToNumValues.abortIfNeeded()
+ keyWithIndexToValue.abortIfNeeded()
+ }
+
+ /**
+ * Get state store checkpoint information of the two state stores for this joiner, after
+ * they finished data processing.
+ *
+ * For [[SymmetricHashJoinStateManagerV1]], this returns the information of the two stores
+ * used for this joiner.
+ */
+ override def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo = {
+ val keyToNumValuesCkptInfo = keyToNumValues.getLatestCheckpointInfo()
+ val keyWithIndexToValueCkptInfo = keyWithIndexToValue.getLatestCheckpointInfo()
+
+ assert(
+ keyToNumValuesCkptInfo.partitionId == keyWithIndexToValueCkptInfo.partitionId,
+ "two state stores in a stream-stream joiner don't return the same partition ID")
+ assert(
+ keyToNumValuesCkptInfo.batchVersion == keyWithIndexToValueCkptInfo.batchVersion,
+ "two state stores in a stream-stream joiner don't return the same batch version")
+ assert(
+ keyToNumValuesCkptInfo.stateStoreCkptId.isDefined ==
+ keyWithIndexToValueCkptInfo.stateStoreCkptId.isDefined,
+ "two state stores in a stream-stream joiner should both return checkpoint ID or not")
+
+ JoinerStateStoreCkptInfo(keyToNumValuesCkptInfo, keyWithIndexToValueCkptInfo)
+ }
+
+ override def metrics: StateStoreMetrics = {
+ val keyToNumValuesMetrics = keyToNumValues.metrics
+ val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics
+ def newDesc(desc: String): String = s"${joinSide.toString.toUpperCase(Locale.ROOT)}: $desc"
+
+ StateStoreMetrics(
+ keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once
+ keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes,
+ keyWithIndexToValueMetrics.customMetrics.map {
+ case (metric, value) => (metric.withNewDesc(desc = newDesc(metric.desc)), value)
+ },
+ // We want to collect instance metrics from both state stores
+ keyWithIndexToValueMetrics.instanceMetrics ++ keyToNumValuesMetrics.instanceMetrics
+ )
+ }
+}
+
+/**
+ * Streaming join state manager that uses 1 state store with virtual column families enabled.
+ * Instead of creating a new state store per join side and store type, this manager
+ * uses virtual column families to separate the data held by the original 4 state stores.
+ */
+class SymmetricHashJoinStateManagerV2(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ keyToNumValuesStateStoreCkptId: Option[String],
+ keyWithIndexToValueStateStoreCkptId: Option[String],
+ stateFormatVersion: Int,
+ skippedNullValueCount: Option[SQLMetric] = None,
+ useStateStoreCoordinator: Boolean = true,
+ snapshotStartVersion: Option[Long] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator) extends SymmetricHashJoinStateManager(
+ joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
+ partitionId, keyToNumValuesStateStoreCkptId, keyWithIndexToValueStateStoreCkptId,
+ stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotStartVersion,
+ joinStoreGenerator) {
+
+ /** Commit all the changes to the state store */
+ override def commit(): Unit = {
+ // Both keyToNumValues and keyWithIndexToValue are using the same state store, so only
+ // one commit is needed.
+ keyToNumValues.commit()
+ }
+
+ /** Abort any changes to the state store if needed */
+ override def abortIfNeeded(): Unit = {
+ keyToNumValues.abortIfNeeded()
+ }
+
+ /**
+ * Get state store checkpoint information of the state store used for this joiner, after
+ * they finished data processing.
+ *
+ * For [[SymmetricHashJoinStateManagerV2]], this returns the information of the single store
+ * used for the entire joiner operator. Both fields of JoinerStateStoreCkptInfo will
+ * be identical.
+ */
+ override def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo = {
+ // Note that both keyToNumValues and keyWithIndexToValue are using the same state store,
+ // so the latest checkpoint info should be the same.
+ // These are returned in a JoinerStateStoreCkptInfo object to remain consistent with
+ // the V1 implementation.
+ val keyToNumValuesCkptInfo = keyToNumValues.getLatestCheckpointInfo()
+ val keyWithIndexToValueCkptInfo = keyWithIndexToValue.getLatestCheckpointInfo()
+
+ assert(keyToNumValuesCkptInfo == keyWithIndexToValueCkptInfo)
+
+ JoinerStateStoreCkptInfo(keyToNumValuesCkptInfo, keyWithIndexToValueCkptInfo)
+ }
+
+ /** Get the state store metrics from the state store manager */
+ override def metrics: StateStoreMetrics = keyToNumValues.metrics
+}
+
+/** Class used to handle state store creation in SymmetricHashJoinStateManager V1 and V2 */
+class JoinStateManagerStoreGenerator() extends Logging {
+
+ // Internally cache the store used by the manager when virtual column families are enabled
+ private var _store: Option[StateStore] = None
+
+ /**
+ * Creates the state store used for join operations, or returns the existing instance
+ * if it has been previously created and virtual column families are enabled.
+ */
+ def getStore(
+ storeProviderId: StateStoreProviderId,
+ keySchema: StructType,
+ valueSchema: StructType,
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ version: Long,
+ stateStoreCkptId: Option[String],
+ stateSchemaBroadcast: Option[StateSchemaBroadcast],
+ useColumnFamilies: Boolean,
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration): StateStore = {
+ if (useColumnFamilies) {
+ // Create the store if it hasn't been created yet, otherwise reuse the existing instance
+ if (_store.isEmpty) {
+ _store = Some(
+ StateStore.get(
+ storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, version,
+ stateStoreCkptId, stateSchemaBroadcast, useColumnFamilies = useColumnFamilies,
+ storeConf, hadoopConf
+ )
+ )
+ }
+ _store.get
+ } else {
+ // Do not use the store saved internally, as we need to create the four distinct stores
+ StateStore.get(
+ storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, version,
+ stateStoreCkptId, stateSchemaBroadcast, useColumnFamilies = useColumnFamilies,
+ storeConf, hadoopConf
+ )
+ }
+ }
+}
+
object SymmetricHashJoinStateManager {
- val supportedVersions = Seq(1, 2)
+ val supportedVersions = Seq(1, 2, 3)
val legacyVersion = 1
+ // scalastyle:off argcount
+ /** Factory method that determines which version of the join state manager should be created */
+ def apply(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ keyToNumValuesStateStoreCkptId: Option[String],
+ keyWithIndexToValueStateStoreCkptId: Option[String],
+ stateFormatVersion: Int,
+ skippedNullValueCount: Option[SQLMetric] = None,
+ useStateStoreCoordinator: Boolean = true,
+ snapshotStartVersion: Option[Long] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator): SymmetricHashJoinStateManager = {
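+ // State format version 3 uses virtual column families and is handled by the single-store
+ // V2 manager; earlier versions use the four-store V1 manager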
+ if (stateFormatVersion == 3) {
+ new SymmetricHashJoinStateManagerV2(
+ joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
+ partitionId, keyToNumValuesStateStoreCkptId, keyWithIndexToValueStateStoreCkptId,
+ stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotStartVersion,
+ joinStoreGenerator
+ )
+ } else {
+ new SymmetricHashJoinStateManagerV1(
+ joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
+ partitionId, keyToNumValuesStateStoreCkptId, keyWithIndexToValueStateStoreCkptId,
+ stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotStartVersion,
+ joinStoreGenerator
+ )
+ }
+ }
+ // scalastyle:on argcount
+
def allStateStoreNames(joinSides: JoinSide*): Seq[String] = {
val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType)
for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield {
@@ -831,7 +1077,7 @@ object SymmetricHashJoinStateManager {
val keyWithIndexSchema = keySchema.add("index", LongType)
val valueSchema = if (stateFormatVersion == 1) {
inputValueAttributes
- } else if (stateFormatVersion == 2) {
+ } else if (stateFormatVersion == 2 || stateFormatVersion == 3) {
inputValueAttributes :+ AttributeReference("matched", BooleanType)()
} else {
throw new IllegalArgumentException("Incorrect state format version! " +
@@ -843,6 +1089,25 @@ object SymmetricHashJoinStateManager {
result
}
+ /** Retrieves the schemas used for join operator state stores that use column families */
+ def getSchemasForStateStoreWithColFamily(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateFormatVersion: Int): Map[String, StateStoreColFamilySchema] = {
+ // Convert the original schemas for state stores into StateStoreColFamilySchema objects
+ val schemas =
+ getSchemaForStateStores(joinSide, inputValueAttributes, joinKeys, stateFormatVersion)
+
+ schemas.map {
+ case (colFamilyName, (keySchema, valueSchema)) =>
+ colFamilyName -> StateStoreColFamilySchema(
+ colFamilyName, 0, keySchema, 0, valueSchema,
+ Some(NoPrefixKeyStateEncoderSpec(keySchema))
+ )
+ }
+ }
+
/**
* Stream-stream join has 4 state stores instead of one. So it will generate 4 different
* checkpoint IDs. The approach we take here is to merge them into one array in the checkpointing
@@ -901,33 +1166,40 @@ object SymmetricHashJoinStateManager {
*/
def getStateStoreCheckpointIds(
partitionId: Int,
- stateInfo: StatefulOperatorStateInfo): JoinStateStoreCheckpointId = {
-
- val stateStoreCkptIds = stateInfo
- .stateStoreCkptIds
- .map(_(partitionId))
- .map(_.map(Option(_)))
- .getOrElse(Array.fill[Option[String]](4)(None))
- JoinStateStoreCheckpointId(
- left = JoinerStateStoreCheckpointId(
- keyToNumValues = stateStoreCkptIds(0),
- valueToNumKeys = stateStoreCkptIds(1)),
- right = JoinerStateStoreCheckpointId(
- keyToNumValues = stateStoreCkptIds(2),
- valueToNumKeys = stateStoreCkptIds(3)))
+ stateInfo: StatefulOperatorStateInfo,
+ useColumnFamiliesForJoins: Boolean): JoinStateStoreCheckpointId = {
+ if (useColumnFamiliesForJoins) {
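+ // With virtual column families there is a single physical store, so the same checkpoint
+ // ID is reused for all four logical stores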
+ val ckpt = stateInfo.stateStoreCkptIds.map(_(partitionId)).map(_.head)
+ JoinStateStoreCheckpointId(
+ left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt),
+ right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt)
+ )
+ } else {
+ val stateStoreCkptIds = stateInfo.stateStoreCkptIds
+ .map(_(partitionId))
+ .map(_.map(Option(_)))
+ .getOrElse(Array.fill[Option[String]](4)(None))
+ JoinStateStoreCheckpointId(
+ left = JoinerStateStoreCheckpointId(
+ keyToNumValues = stateStoreCkptIds(0),
+ valueToNumKeys = stateStoreCkptIds(1)),
+ right = JoinerStateStoreCheckpointId(
+ keyToNumValues = stateStoreCkptIds(2),
+ valueToNumKeys = stateStoreCkptIds(3)))
+ }
}
- private sealed trait StateStoreType
+ private[state] sealed trait StateStoreType
- private case object KeyToNumValuesType extends StateStoreType {
+ private[state] case object KeyToNumValuesType extends StateStoreType {
override def toString(): String = "keyToNumValues"
}
- private case object KeyWithIndexToValueType extends StateStoreType {
+ private[state] case object KeyWithIndexToValueType extends StateStoreType {
override def toString(): String = "keyWithIndexToValue"
}
- private def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = {
+ private[state] def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = {
s"$joinSide-$storeType"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index af47229dfa88c..d92e5dbae1aa1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
@@ -203,27 +203,46 @@ trait StateStoreWriter
def operatorStateMetadataVersion: Int = 1
- override lazy val metrics = statefulOperatorCustomMetrics ++ Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
- "numRowsDroppedByWatermark" -> SQLMetrics.createMetric(sparkContext,
- "number of rows which are dropped by watermark"),
- "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
- "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"),
- "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to update"),
- "numRemovedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of removed state rows"),
- "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to remove"),
- "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes"),
- "stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by state"),
- "numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext,
- "number of state store instances")
- ) ++ stateStoreCustomMetrics ++ pythonMetrics ++ stateStoreInstanceMetrics
+ override lazy val metrics = {
+ // Lazily initialize instance metrics, but do not include them in the regular metrics map
+ instanceMetrics
+ statefulOperatorCustomMetrics ++ Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+ "numRowsDroppedByWatermark" -> SQLMetrics
+ .createMetric(sparkContext, "number of rows which are dropped by watermark"),
+ "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
+ "numUpdatedStateRows" -> SQLMetrics
+ .createMetric(sparkContext, "number of updated state rows"),
+ "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to update"),
+ "numRemovedStateRows" -> SQLMetrics
+ .createMetric(sparkContext, "number of removed state rows"),
+ "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to remove"),
+ "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes"),
+ "stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by state"),
+ "numStateStoreInstances" -> SQLMetrics
+ .createMetric(sparkContext, "number of state store instances")
+ ) ++ stateStoreCustomMetrics ++ pythonMetrics
+ }
- val stateStoreNames: Seq[String] = Seq(StateStoreId.DEFAULT_STORE_NAME)
+ /**
+ * Map of all instance metrics (each carrying its partition ID and store name) to
+ * their SQLMetric counterparts.
+ *
+ * The instance metric objects hold additional information on how to report these metrics,
+ * while the SQLMetric objects store the metric values.
+ *
+ * This map is similar to the metrics map, but needs to be kept separate to prevent propagating
+ * all initialized instance metrics to SparkUI.
+ */
+ lazy val instanceMetrics: Map[StateStoreInstanceMetric, SQLMetric] =
+ stateStoreInstanceMetrics
+
+ override def resetMetrics(): Unit = {
+ super.resetMetrics()
+ instanceMetrics.valuesIterator.foreach(_.reset())
+ }
- // This is used to relate metric names back to their original metric object,
- // which holds information on how to report the metric during getProgress.
- lazy val instanceMetricConfiguration: Map[String, StateStoreInstanceMetric] =
- stateStoreInstanceMetricObjects
+ val stateStoreNames: Seq[String] = Seq(StateStoreId.DEFAULT_STORE_NAME)
// This method is only used to fetch the state schema directory path for
// operators that use StateSchemaV3, as prior versions only use a single
@@ -327,28 +346,31 @@ trait StateStoreWriter
* the driver after this SparkPlan has been executed and metrics have been updated.
*/
def getProgress(): StateOperatorProgress = {
- val instanceMetricsToReport = instanceMetricConfiguration
+ val instanceMetricsToReport = instanceMetrics
.filter {
- case (name, metricConfig) =>
+ case (metricConf, sqlMetric) =>
// Keep instance metrics that are updated or aren't marked to be ignored,
// as their initial value could still be important.
- !metricConfig.ignoreIfUnchanged || !longMetric(name).isZero
+ !metricConf.ignoreIfUnchanged || !sqlMetric.isZero
}
.groupBy {
// Group all instance metrics underneath their common metric prefix
// to ignore partition and store names.
- case (name, metricConfig) => metricConfig.metricPrefix
+ case (metricConf, sqlMetric) => metricConf.metricPrefix
}
.flatMap {
case (_, metrics) =>
// Select at most N metrics based on the metric's defined ordering
// to report to the driver. For example, ascending order would be taking the N smallest.
- val metricConf = metrics.head._2
+ val metricConf = metrics.head._1
metrics
.map {
- case (_, metric) =>
- metric.name -> (if (longMetric(metric.name).isZero) metricConf.initValue
- else longMetric(metric.name).value)
+ case (metricConf, sqlMetric) =>
+ // Use the metric name, as it will be combined with custom metrics in progress reports.
+ // Metrics still at their initial value at this stage should not be dropped;
+ // they should report their real initial value.
+ metricConf.name -> (if (sqlMetric.isZero) metricConf.initValue
+ else sqlMetric.value)
}
.toSeq
.sortBy(_._2)(metricConf.ordering)
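For intuition, the filter/group/take-N step above behaves roughly like the self-contained sketch below; the metric names, values, and the limit of 2 are made up, and the real code keys on StateStoreInstanceMetric objects and uses each metric's own ordering and initValue.

// Hypothetical values for one instance metric prefix across four partitions.
val reportLimit = 2
val instanceValues = Seq(
  "customMetric.partition_0_default" -> 120L,
  "customMetric.partition_1_default" -> 45L,
  "customMetric.partition_2_default" -> 300L,
  "customMetric.partition_3_default" -> 0L) // unchanged since initialization

val reported = instanceValues
  .filter { case (_, value) => value != 0L } // drop unchanged metrics marked ignoreIfUnchanged
  .sortBy { case (_, value) => value }       // an ascending ordering keeps the smallest values
  .take(reportLimit)
  .toMap
// reported == Map("customMetric.partition_1_default" -> 45L, "customMetric.partition_0_default" -> 120L)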
@@ -434,12 +456,11 @@ trait StateStoreWriter
}
protected def setStoreInstanceMetrics(
- instanceMetrics: Map[StateStoreInstanceMetric, Long]): Unit = {
- instanceMetrics.foreach {
+ otherStoreInstanceMetrics: Map[StateStoreInstanceMetric, Long]): Unit = {
+ otherStoreInstanceMetrics.foreach {
case (metric, value) =>
- val metricConfig = instanceMetricConfiguration(metric.name)
// Update the metric's value based on the defined combine method
- longMetric(metric.name).set(metricConfig.combine(longMetric(metric.name), value))
+ instanceMetrics(metric).set(metric.combine(instanceMetrics(metric), value))
}
}
@@ -450,13 +471,7 @@ trait StateStoreWriter
}.toMap
}
- private def stateStoreInstanceMetrics: Map[String, SQLMetric] = {
- instanceMetricConfiguration.map {
- case (name, metric) => (name, metric.createSQLMetric(sparkContext))
- }
- }
-
- private def stateStoreInstanceMetricObjects: Map[String, StateStoreInstanceMetric] = {
+ private def stateStoreInstanceMetrics: Map[StateStoreInstanceMetric, SQLMetric] = {
val provider = StateStoreProvider.create(conf.stateStoreProviderClass)
val maxPartitions = stateInfo.map(_.numPartitions).getOrElse(conf.defaultNumShufflePartitions)
@@ -464,7 +479,7 @@ trait StateStoreWriter
provider.supportedInstanceMetrics.flatMap { metric =>
stateStoreNames.map { storeName =>
val metricWithPartition = metric.withNewId(partitionId, storeName)
- (metricWithPartition.name, metricWithPartition)
+ (metricWithPartition, metricWithPartition.createSQLMetric(sparkContext))
}
}
}.toMap
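The cardinality of the map built above follows directly from the nesting: every instance metric supported by the provider is instantiated once per (partition, store name) pair, and the fully qualified metric object itself, not its name, becomes the key. A rough sketch with hypothetical names and counts:

// Hypothetical stand-ins for provider.supportedInstanceMetrics and stateStoreNames;
// only the shape of the expansion mirrors the code above.
val numPartitions = 4
val storeNames = Seq("keyToNumValues", "keyWithIndexToValue")
val hypotheticalIds = for {
  partitionId <- 0 until numPartitions
  storeName <- storeNames
} yield s"partition_${partitionId}_$storeName"
// 4 partitions x 2 store names = 8 entries, each paired with its own SQLMetric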
@@ -1475,3 +1490,59 @@ case class StreamingDeduplicateWithinWatermarkExec(
override protected def withNewChildInternal(
newChild: SparkPlan): StreamingDeduplicateWithinWatermarkExec = copy(child = newChild)
}
+
+trait SchemaValidationUtils extends Logging {
+
+ // Determines whether the operator should be able to evolve its schema
+ val schemaEvolutionEnabledForOperator: Boolean = false
+
+ // This method returns the column family schemas and checks whether the fields in the
+ // schemas are nullable. If Avro encoding is used, we want to enforce nullability.
+ def getColFamilySchemas(shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema]
+
+ def validateAndWriteStateSchema(
+ hadoopConf: Configuration,
+ batchId: Long,
+ stateSchemaVersion: Int,
+ info: StatefulOperatorStateInfo,
+ stateSchemaDir: Path,
+ session: SparkSession,
+ operatorStateMetadataVersion: Int = 2,
+ stateStoreEncodingFormat: String = StateStoreEncoding.UnsafeRow.toString
+ ): List[StateSchemaValidationResult] = {
+ assert(stateSchemaVersion >= 3)
+ val usingAvro = stateStoreEncodingFormat == StateStoreEncoding.Avro.toString
+ val newSchemas = getColFamilySchemas(usingAvro)
+ val newStateSchemaFilePath =
+ new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")
+ val metadataPath = new Path(info.checkpointLocation, s"${info.operatorId}")
+ val metadataReader = OperatorStateMetadataReader.createReader(
+ metadataPath, hadoopConf, operatorStateMetadataVersion, batchId)
+ val operatorStateMetadata = try {
+ metadataReader.read()
+ } catch {
+ // If this is the first time we are running the query, there will be no metadata
+ // and this error is expected. In this case, we return None.
+ case _: Exception if batchId == 0 =>
+ None
+ }
+
+ val oldStateSchemaFilePaths: List[Path] = operatorStateMetadata match {
+ case Some(metadata) =>
+ metadata match {
+ case v2: OperatorStateMetadataV2 =>
+ v2.stateStoreInfo.head.stateSchemaFilePaths.map(new Path(_))
+ case _ => List.empty
+ }
+ case None => List.empty
+ }
+ // The new state schema file is written here, using the schema list built above.
+ List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(
+ info, hadoopConf, newSchemas.values.toList, session.sessionState, stateSchemaVersion,
+ storeName = StateStoreId.DEFAULT_STORE_NAME,
+ oldSchemaFilePaths = oldStateSchemaFilePaths,
+ newSchemaFilePath = Some(newStateSchemaFilePath),
+ schemaEvolutionEnabled = usingAvro && schemaEvolutionEnabledForOperator))
+ }
+}
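As a rough illustration of the new trait's contract, a hypothetical mixin (placeholder schemas, import paths assumed from the surrounding packages) only has to supply its column family schemas and may opt into evolution; the validation call itself is inherited from the trait.

import org.apache.spark.sql.execution.streaming.SchemaValidationUtils
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStoreColFamilySchema}
import org.apache.spark.sql.types.{LongType, StringType, StructType}

// Hypothetical mixin: a real operator would also extend StateStoreWriter and call
// validateAndWriteStateSchema from its planning path.
object ExampleColFamilySchemas extends SchemaValidationUtils {
  // Opt into schema evolution; it only takes effect when Avro encoding is used.
  override val schemaEvolutionEnabledForOperator: Boolean = true

  override def getColFamilySchemas(
      shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
    // Placeholder schemas; fields are marked nullable when Avro encoding is in use.
    val keySchema = new StructType().add("groupingKey", StringType, nullable = shouldBeNullable)
    val valueSchema = new StructType().add("count", LongType, nullable = shouldBeNullable)
    Map("default" -> StateStoreColFamilySchema(
      "default", 0, keySchema, 0, valueSchema, Some(NoPrefixKeyStateEncoderSpec(keySchema))))
  }
}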
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
index 2a6c15df5d1db..9d10e5f9545c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
@@ -26,6 +26,7 @@ import org.json4s.jackson.JsonMethods.parse
import org.apache.spark.JobExecutionStatus
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.UI.UI_SQL_GROUP_SUB_EXECUTION_ENABLED
import org.apache.spark.ui.{UIUtils, WebUIPage}
class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging {
@@ -33,6 +34,8 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
private val pandasOnSparkConfPrefix = "pandas_on_Spark."
private val sqlStore = parent.sqlStore
+ private val groupSubExecutionEnabled = parent.conf.get(UI_SQL_GROUP_SUB_EXECUTION_ENABLED)
+
override def render(request: HttpServletRequest): Seq[Node] = {
val parameterExecutionId = request.getParameter("id")
@@ -71,10 +74,46 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
Duration: {UIUtils.formatDuration(duration)}
+ {
+ if (executionUIData.rootExecutionId != executionId) {
+