AuronSparkTestSettings.scala
@@ -17,6 +17,7 @@
 package org.apache.auron.utils
 
 import org.apache.spark.sql._
+import org.apache.spark.sql.execution.joins.AuronExistenceJoinSuite
 
 class AuronSparkTestSettings extends SparkTestSettings {
 {
@@ -42,6 +43,8 @@ class AuronSparkTestSettings extends SparkTestSettings {

   enableSuite[AuronTypedImperativeAggregateSuite]
 
+  enableSuite[AuronExistenceJoinSuite]
+
   // Will be implemented in the future.
   override def getSQLQueryTestSettings = new SQLQueryTestSettings {
     override def getResourceFilePath: String = ???
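The settings class registers the new suite through enableSuite. The SparkTestSettings API itself is not shown in this diff; as a rough sketch only, enableSuite-style registries are commonly implemented with a ClassTag-based set of suite class names, along these lines (all names below are hypothetical, not Auron's actual API):

import scala.collection.mutable
import scala.reflect.ClassTag

abstract class TestSettingsSketch {
  private val enabledSuites = mutable.LinkedHashSet.empty[String]

  // Record the fully-qualified class name of the suite to enable.
  def enableSuite[T: ClassTag]: Unit =
    enabledSuites += implicitly[ClassTag[T]].runtimeClass.getName

  // Queried by the test runner to decide whether a suite participates.
  def isSuiteEnabled(suiteName: String): Boolean =
    enabledSuites.contains(suiteName)
}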
AuronExistenceJoinSuite.scala (new file)
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.SparkTestsSharedSessionBase
+
+class AuronExistenceJoinSuite extends ExistenceJoinSuite with SparkTestsSharedSessionBase {}
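ExistenceJoin is Spark's internal join type for EXISTS/IN subquery predicates that cannot be rewritten to a plain semi or anti join, typically because the predicate is OR-ed with another condition; the join appends a boolean exists column that the filter then consumes. A minimal self-contained sketch (not part of this patch) of a query shape that Spark typically plans as an ExistenceJoin:

import org.apache.spark.sql.SparkSession

object ExistenceJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("existence-join-sketch").getOrCreate()
    import spark.implicits._

    Seq((1, 10), (2, 20), (3, 30)).toDF("id", "v").createOrReplaceTempView("t1")
    Seq(2, 3).toDF("id").createOrReplaceTempView("t2")

    // Because the IN predicate is OR-ed with v = 10, it cannot be rewritten to a
    // plain LeftSemi join; the optimized plan contains an ExistenceJoin instead.
    val df = spark.sql("SELECT * FROM t1 WHERE id IN (SELECT id FROM t2) OR v = 10")
    df.explain(true) // look for "ExistenceJoin(exists#...)" in the optimized logical plan
    df.show()        // keeps row 1 (v = 10) and rows 2, 3 (id matches)

    spark.stop()
  }
}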
NativeBroadcastJoinBase.scala
@@ -27,12 +27,7 @@ import org.apache.spark.sql.auron.NativeRDD
 import org.apache.spark.sql.auron.NativeSupports
 import org.apache.spark.sql.auron.Shims
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.FullOuter
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.LeftAnti
-import org.apache.spark.sql.catalyst.plans.LeftOuter
-import org.apache.spark.sql.catalyst.plans.LeftSemi
-import org.apache.spark.sql.catalyst.plans.RightOuter
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.BinaryExecNode
 import org.apache.spark.sql.execution.SparkPlan
@@ -43,7 +38,7 @@ import org.apache.spark.sql.types.LongType

 import org.apache.auron.{protobuf => pb}
 import org.apache.auron.metric.SparkMetricNode
-import org.apache.auron.protobuf.JoinOn
+import org.apache.auron.protobuf.{EmptyPartitionsExecNode, JoinOn, PhysicalPlanNode}
 
 abstract class NativeBroadcastJoinBase(
     override val left: SparkPlan,
@@ -127,44 +122,99 @@
   override def doExecuteNative(): NativeRDD = {
     val leftRDD = NativeHelper.executeNative(left)
     val rightRDD = NativeHelper.executeNative(right)
-    val nativeMetrics = SparkMetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil)
-    val nativeSchema = this.nativeSchema
-    val nativeJoinType = this.nativeJoinType
-    val nativeJoinOn = this.nativeJoinOn
 
     val (probedRDD, builtRDD) = broadcastSide match {
       case BroadcastLeft => (rightRDD, leftRDD)
       case BroadcastRight => (leftRDD, rightRDD)
     }
 
+    // Handle the edge case when the probed side is empty (no partitions).
+    // This matches Spark's BroadcastNestedLoopJoinExec behavior for the condition.isEmpty case:
+    //   val streamExists = !streamed.executeTake(1).isEmpty
+    //   if (streamExists == exists) sparkContext.makeRDD(relation.value)
+    //   else sparkContext.emptyRDD
+    // where exists = true for Semi, false for Anti.
+    //
+    // Note: this optimization only applies to Semi/Anti joins.
+    if (probedRDD.partitions.isEmpty) {
+      joinType match {
+        case LeftAnti =>
+          return builtRDD
+        case LeftSemi =>
+          return probedRDD
+        case _ =>
+      }
+    }
+
+    val nativeMetrics = SparkMetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil)
+    val nativeSchema = this.nativeSchema
+    val nativeJoinType = this.nativeJoinType
+    val nativeJoinOn = this.nativeJoinOn
+
     val probedShuffleReadFull = probedRDD.isShuffleReadFull && (broadcastSide match {
       case BroadcastLeft =>
         Seq(FullOuter, RightOuter).contains(joinType)
       case BroadcastRight =>
         Seq(FullOuter, LeftOuter, LeftSemi, LeftAnti).contains(joinType)
     })
 
+    // For ExistenceJoin with an empty probed side, use builtRDD.partitions to ensure
+    // the native join can execute and finish() will output all build rows with exists = false.
+    val (rddPartitions, rddPartitioner, rddDependencies) =
+      if (probedRDD.partitions.isEmpty && joinType.isInstanceOf[ExistenceJoin]) {
+        (builtRDD.partitions, builtRDD.partitioner, new OneToOneDependency(builtRDD) :: Nil)
+      } else {
+        (probedRDD.partitions, probedRDD.partitioner, new OneToOneDependency(probedRDD) :: Nil)
+      }
+
     new NativeRDD(
       sparkContext,
       nativeMetrics,
-      probedRDD.partitions,
-      rddPartitioner = probedRDD.partitioner,
-      rddDependencies = new OneToOneDependency(probedRDD) :: Nil,
+      rddPartitions,
+      rddPartitioner = rddPartitioner,
+      rddDependencies = rddDependencies,
       probedShuffleReadFull,
       (partition, context) => {
         val partition0 = new Partition() {
           override def index: Int = 0
         }
-        val (leftChild, rightChild) = broadcastSide match {
-          case BroadcastLeft =>
-            (
-              leftRDD.nativePlan(partition0, context),
-              rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
-          case BroadcastRight =>
-            (
-              leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
-              rightRDD.nativePlan(partition0, context))
-        }
+        val (leftChild, rightChild) =
+          if (probedRDD.partitions.isEmpty && joinType.isInstanceOf[ExistenceJoin]) {
+            val probedSchema = broadcastSide match {
+              case BroadcastLeft => Util.getNativeSchema(right.output)
+              case BroadcastRight => Util.getNativeSchema(left.output)
+            }
+            val emptyProbedPlan = PhysicalPlanNode
+              .newBuilder()
+              .setEmptyPartitions(
+                EmptyPartitionsExecNode
+                  .newBuilder()
+                  .setNumPartitions(1)
+                  .setSchema(probedSchema)
+                  .build())
+              .build()
+            broadcastSide match {
+              case BroadcastLeft =>
+                (
+                  leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
+                  emptyProbedPlan)
+              case BroadcastRight =>
+                (
+                  emptyProbedPlan,
+                  rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
+            }
+          } else {
+            broadcastSide match {
+              case BroadcastLeft =>
+                (
+                  leftRDD.nativePlan(partition0, context),
+                  rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
+              case BroadcastRight =>
+                (
+                  leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
+                  rightRDD.nativePlan(partition0, context))
+            }
+          }
         val cachedBuildHashMapId = s"bhm_stage${context.stageId}_rdd${builtRDD.id}"
 
         val broadcastJoinExec = pb.BroadcastJoinExecNode
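The shortcut above mirrors relation-level join semantics when one side is empty: a left semi join against an empty right side yields no rows, a left anti join yields every left row, and an ExistenceJoin marks every row with exists = false, so only the other disjunct of the filter can keep rows. A minimal self-contained sketch (not part of this patch) that exercises these three cases:

import org.apache.spark.sql.SparkSession

object EmptyProbedSideSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("empty-probed-sketch").getOrCreate()
    import spark.implicits._

    val left = Seq((1, 10), (2, 20)).toDF("id", "v")
    val right = Seq.empty[Int].toDF("id") // empty relation standing in for the empty probed side

    left.join(right, Seq("id"), "left_semi").show() // no rows survive
    left.join(right, Seq("id"), "left_anti").show() // both left rows survive

    left.createOrReplaceTempView("l")
    right.createOrReplaceTempView("r")
    // ExistenceJoin: exists is false for every row, so only the v = 10 disjunct keeps rows.
    spark.sql("SELECT * FROM l WHERE id IN (SELECT id FROM r) OR v = 10").show()

    spark.stop()
  }
}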