diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
index 4884eda91..5558d5f06 100644
--- a/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
+++ b/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
@@ -17,6 +17,7 @@
 package org.apache.auron.utils
 
 import org.apache.spark.sql._
+import org.apache.spark.sql.execution.joins.AuronExistenceJoinSuite
 
 class AuronSparkTestSettings extends SparkTestSettings {
   {
@@ -42,6 +43,8 @@ class AuronSparkTestSettings extends SparkTestSettings {
 
     enableSuite[AuronTypedImperativeAggregateSuite]
 
+    enableSuite[AuronExistenceJoinSuite]
+
     // Will be implemented in the future.
     override def getSQLQueryTestSettings = new SQLQueryTestSettings {
       override def getResourceFilePath: String = ???
diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/execution/joins/AuronExistenceJoinSuite.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/execution/joins/AuronExistenceJoinSuite.scala
new file mode 100644
index 000000000..f8fc9660c
--- /dev/null
+++ b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/execution/joins/AuronExistenceJoinSuite.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.SparkTestsSharedSessionBase
+
+class AuronExistenceJoinSuite extends ExistenceJoinSuite with SparkTestsSharedSessionBase {}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
index 3281947c8..edd172a7d 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
@@ -27,12 +27,7 @@ import org.apache.spark.sql.auron.NativeRDD
 import org.apache.spark.sql.auron.NativeSupports
 import org.apache.spark.sql.auron.Shims
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.FullOuter
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.LeftAnti
-import org.apache.spark.sql.catalyst.plans.LeftOuter
-import org.apache.spark.sql.catalyst.plans.LeftSemi
-import org.apache.spark.sql.catalyst.plans.RightOuter
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.BinaryExecNode
 import org.apache.spark.sql.execution.SparkPlan
@@ -43,7 +38,7 @@ import org.apache.spark.sql.types.LongType
 
 import org.apache.auron.{protobuf => pb}
 import org.apache.auron.metric.SparkMetricNode
-import org.apache.auron.protobuf.JoinOn
+import org.apache.auron.protobuf.{EmptyPartitionsExecNode, JoinOn, PhysicalPlanNode}
 
 abstract class NativeBroadcastJoinBase(
     override val left: SparkPlan,
@@ -127,16 +122,37 @@ abstract class NativeBroadcastJoinBase(
 
   override def doExecuteNative(): NativeRDD = {
     val leftRDD = NativeHelper.executeNative(left)
     val rightRDD = NativeHelper.executeNative(right)
-    val nativeMetrics = SparkMetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil)
-    val nativeSchema = this.nativeSchema
-    val nativeJoinType = this.nativeJoinType
-    val nativeJoinOn = this.nativeJoinOn
 
     val (probedRDD, builtRDD) = broadcastSide match {
       case BroadcastLeft => (rightRDD, leftRDD)
       case BroadcastRight => (leftRDD, rightRDD)
     }
 
+    // Handle the edge case where the probed (streamed) side is empty (no partitions).
+    // This matches Spark's BroadcastNestedLoopJoinExec behavior for the BuildLeft,
+    // condition.isEmpty case:
+    //   val streamExists = !streamed.executeTake(1).isEmpty
+    //   if (streamExists == exists) sparkContext.makeRDD(relation.value)
+    //   else sparkContext.emptyRDD
+    //   where exists = true for Semi and false for Anti.
+    // Note: this shortcut only applies to Semi/Anti joins with the left side broadcast.
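+    // ExistenceJoin is deliberately excluded from this early return: it is handled
+    // below by substituting an EmptyPartitionsExecNode for the empty probed side.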
+    if (probedRDD.partitions.isEmpty && broadcastSide == BroadcastLeft) {
+      joinType match {
+        case LeftAnti =>
+          return builtRDD
+        case LeftSemi =>
+          return probedRDD
+        case _ =>
+      }
+    }
+
+    val nativeMetrics = SparkMetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil)
+    val nativeSchema = this.nativeSchema
+    val nativeJoinType = this.nativeJoinType
+    val nativeJoinOn = this.nativeJoinOn
+
     val probedShuffleReadFull = probedRDD.isShuffleReadFull && (broadcastSide match {
       case BroadcastLeft => Seq(FullOuter, RightOuter).contains(joinType)
@@ -144,27 +160,65 @@
         Seq(FullOuter, LeftOuter, LeftSemi, LeftAnti).contains(joinType)
     })
 
+    // For ExistenceJoin with an empty probed side, use builtRDD.partitions so that the
+    // native join still executes and finish() outputs all build rows with exists = false.
+    val (rddPartitions, rddPartitioner, rddDependencies) =
+      if (probedRDD.partitions.isEmpty && joinType.isInstanceOf[ExistenceJoin]) {
+        (builtRDD.partitions, builtRDD.partitioner, new OneToOneDependency(builtRDD) :: Nil)
+      } else {
+        (probedRDD.partitions, probedRDD.partitioner, new OneToOneDependency(probedRDD) :: Nil)
+      }
+
     new NativeRDD(
       sparkContext,
       nativeMetrics,
-      probedRDD.partitions,
-      rddPartitioner = probedRDD.partitioner,
-      rddDependencies = new OneToOneDependency(probedRDD) :: Nil,
+      rddPartitions,
+      rddPartitioner = rddPartitioner,
+      rddDependencies = rddDependencies,
       probedShuffleReadFull,
       (partition, context) => {
         val partition0 = new Partition() {
           override def index: Int = 0
         }
-        val (leftChild, rightChild) = broadcastSide match {
-          case BroadcastLeft =>
-            (
-              leftRDD.nativePlan(partition0, context),
-              rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
-          case BroadcastRight =>
-            (
-              leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
-              rightRDD.nativePlan(partition0, context))
-        }
+        val (leftChild, rightChild) =
+          if (probedRDD.partitions.isEmpty && joinType.isInstanceOf[ExistenceJoin]) {
+            val probedSchema = broadcastSide match {
+              case BroadcastLeft => Util.getNativeSchema(right.output)
+              case BroadcastRight => Util.getNativeSchema(left.output)
+            }
+            val emptyProbedPlan = PhysicalPlanNode
+              .newBuilder()
+              .setEmptyPartitions(
+                EmptyPartitionsExecNode
+                  .newBuilder()
+                  .setNumPartitions(1)
+                  .setSchema(probedSchema)
+                  .build())
+              .build()
+            broadcastSide match {
+              case BroadcastLeft =>
+                (
+                  leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
+                  emptyProbedPlan)
+              case BroadcastRight =>
+                (
+                  emptyProbedPlan,
+                  rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
+            }
+          } else {
+            broadcastSide match {
+              case BroadcastLeft =>
+                (
+                  leftRDD.nativePlan(partition0, context),
+                  rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
+              case BroadcastRight =>
+                (
+                  leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
+                  rightRDD.nativePlan(partition0, context))
+            }
+          }
         val cachedBuildHashMapId = s"bhm_stage${context.stageId}_rdd${builtRDD.id}"
         val broadcastJoinExec = pb.BroadcastJoinExecNode
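+          // Assemble the native BroadcastJoin plan node; leftChild/rightChild above
+          // already account for the empty-probed-side ExistenceJoin substitution.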