77 changes: 71 additions & 6 deletions native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -160,11 +160,76 @@ impl SparkBloomFilter {
}

pub fn merge_filter(&mut self, other: &[u8]) {
assert_eq!(
other.len(),
self.bits.byte_size(),
"Cannot merge SparkBloomFilters with different lengths."
);
self.bits.merge_bits(other);
// Extract bits data if other is in Spark's full serialization format
// We need to compute the expected size and extract data before borrowing self.bits mutably
let expected_bits_size = self.bits.byte_size();
const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)

let bits_data = if other.len() >= SPARK_HEADER_SIZE {
[Reviewer comment] Should this be strictly greater than SPARK_HEADER_SIZE?

// Check if this is Spark's serialization format by reading the version
let version = i32::from_be_bytes([
other[0], other[1], other[2], other[3],
]);
if version == SPARK_BLOOM_FILTER_VERSION_1 {
[Reviewer comment] Is this sufficient to ensure that this is a Spark bloom filter? Isn't there a chance the starting 4 bytes of the Comet bloom filter might match the pattern?

// This is Spark's full format, parse it to extract bits data
let num_words = i32::from_be_bytes([
other[8], other[9], other[10], other[11],
]) as usize;
let bits_start = SPARK_HEADER_SIZE;
let bits_end = bits_start + (num_words * 8);

// Verify the buffer is large enough
if bits_end > other.len() {
panic!(
"Cannot merge SparkBloomFilters: buffer too short. Expected at least {} bytes ({} words), got {} bytes",
bits_end,
num_words,
other.len()
);
}

// Check if the incoming bloom filter has compatible size
let incoming_bits_size = bits_end - bits_start;
if incoming_bits_size != expected_bits_size {
panic!(
[Reviewer comment] Can we use CometError::Internal(String) instead of panic!? (You'll need to return a Result.)

"Cannot merge SparkBloomFilters with incompatible sizes. Expected {} bytes ({} words), got {} bytes ({} words) from Spark partial aggregate. Full buffer size: {} bytes",
expected_bits_size,
self.bits.word_size(),
incoming_bits_size,
num_words,
other.len()
);
}

// Extract just the bits portion
&other[bits_start..bits_end]
} else if other.len() == expected_bits_size {
// Not Spark format but size matches, treat as raw bits data (Comet format)
other
} else {
// Size doesn't match and not Spark format - provide helpful error
panic!(
"Cannot merge SparkBloomFilters: unexpected format. Expected {} bytes (Comet format) or Spark format (version 1, {} bytes header + bits), but got {} bytes with version {}",
expected_bits_size,
SPARK_HEADER_SIZE,
other.len(),
version
);
}
} else {
// Too short to be Spark format
if other.len() != expected_bits_size {
panic!(
"Cannot merge SparkBloomFilters: buffer too short. Expected {} bytes (Comet format) or at least {} bytes (Spark format), got {} bytes",
expected_bits_size,
SPARK_HEADER_SIZE,
other.len()
);
}
// Size matches, treat as raw bits data
other
};

self.bits.merge_bits(bits_data);
}
}
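
To make the reviewer's suggestion above concrete, here is a minimal sketch of a Result-based merge path. The CometError::Internal(String) variant is taken from the review comment; the enum definition, helper name, and signature below are illustrative stand-ins, not the actual Comet API.

// Sketch only: CometError here is a stand-in for Comet's real error type.
#[derive(Debug)]
pub enum CometError {
    Internal(String),
}

const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
const SPARK_HEADER_SIZE: usize = 12; // version (4) + num_hash_functions (4) + num_words (4)

/// Returns the raw bit words contained in `other`, accepting either Comet's
/// raw-bits layout or Spark's full serialization (12-byte header + bits),
/// reporting incompatibilities as errors instead of panicking.
fn extract_bits(other: &[u8], expected_bits_size: usize) -> Result<&[u8], CometError> {
    if other.len() >= SPARK_HEADER_SIZE {
        let version = i32::from_be_bytes([other[0], other[1], other[2], other[3]]);
        if version == SPARK_BLOOM_FILTER_VERSION_1 {
            let num_words =
                i32::from_be_bytes([other[8], other[9], other[10], other[11]]) as usize;
            let bits_end = SPARK_HEADER_SIZE + num_words * 8;
            if bits_end > other.len() || bits_end - SPARK_HEADER_SIZE != expected_bits_size {
                return Err(CometError::Internal(format!(
                    "incompatible Spark bloom filter: expected {expected_bits_size} bits bytes, buffer has {} bytes",
                    other.len()
                )));
            }
            return Ok(&other[SPARK_HEADER_SIZE..bits_end]);
        }
    }
    if other.len() == expected_bits_size {
        Ok(other) // raw Comet bits
    } else {
        Err(CometError::Internal(format!(
            "unexpected bloom filter buffer: expected {expected_bits_size} bytes, got {}",
            other.len()
        )))
    }
}

merge_filter itself would then return Result<(), CometError>, call a helper like this with `?`, and finish with self.bits.merge_bits(bits_data); callers in the aggregate path would need to propagate the Result.
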
@@ -1069,7 +1069,13 @@ trait CometBaseAggregate {
val multiMode = modes.size > 1
// For a final mode HashAggregate, we only need to transform the HashAggregate
// if there is Comet partial aggregation.
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
// Exception: BloomFilterAggregate supports Spark partial / Comet final because
// merge_filter() handles Spark's serialization format (12-byte header + bits).
val hasBloomFilterAgg = aggregate.aggregateExpressions.exists(expr =>
expr.aggregateFunction.getClass.getSimpleName == "BloomFilterAggregate")
val sparkFinalMode = modes.contains(Final) &&
findCometPartialAgg(aggregate.child).isEmpty &&
!hasBloomFilterAgg

if (multiMode || sparkFinalMode) {
return None
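
The planner-level exception above relies on both engines agreeing on the byte layout that merge_filter now parses: a 12-byte big-endian header followed by 8-byte words. As a reference point, a hypothetical fixture builder (not part of the codebase; names and signature are illustrative only) for that layout might look like this:

/// Builds a buffer in Spark's BloomFilter serialization layout:
///   [version: i32 BE][numHashFunctions: i32 BE][numWords: i32 BE][numWords x u64 BE]
fn spark_format_buffer(num_hash_functions: i32, words: &[u64]) -> Vec<u8> {
    let mut buf = Vec::with_capacity(12 + words.len() * 8);
    buf.extend_from_slice(&1i32.to_be_bytes()); // SPARK_BLOOM_FILTER_VERSION_1
    buf.extend_from_slice(&num_hash_functions.to_be_bytes());
    buf.extend_from_slice(&(words.len() as i32).to_be_bytes());
    for word in words {
        buf.extend_from_slice(&word.to_be_bytes());
    }
    buf
}

A buffer built this way can be fed directly to merge_filter in a Rust unit test, which mirrors what the Scala test below exercises end-to-end through a Spark partial / Comet final plan.
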
69 changes: 66 additions & 3 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -32,11 +32,12 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Hex}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate, Final, Partial}
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SparkPlan, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
@@ -1149,6 +1150,68 @@ class CometExecSuite extends CometTestBase {
spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
}

test("bloom_filter_agg - Spark partial / Comet final merge") {
// This test exercises the merge_filter() fix that handles Spark's full serialization
// format (12-byte header + bits) when merging from Spark partial to Comet final aggregates.
val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
spark.sessionState.functionRegistry.registerFunction(
funcId_bloom_filter_agg,
new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
(children: Seq[Expression]) =>
children.size match {
case 1 => new BloomFilterAggregate(children.head)
case 2 => new BloomFilterAggregate(children.head, children(1))
case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
})

// Helper to count operators in plan
def countOperators(plan: SparkPlan, opClass: Class[_]): Int = {
stripAQEPlan(plan).collect {
case stage: QueryStageExec =>
countOperators(stage.plan, opClass)
case op if opClass.isAssignableFrom(op.getClass) => 1
}.sum
}

withParquetTable(
(0 until 1000)
.map(_ => (Random.nextInt(1000), Random.nextInt(100))),
"tbl") {

withSQLConf(
// Disable Comet partial aggregates to force Spark partial / Comet final scenario
CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
CometConf.COMET_EXEC_AGGREGATE_ENABLED.key -> "true",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {

val df = sql(
"SELECT bloom_filter_agg(cast(_2 as long), cast(1000 as long)) FROM tbl GROUP BY _1")

// Verify the query executes successfully (tests merge_filter compatibility)
checkSparkAnswer(df)

// Verify we have Spark partial aggregates and Comet final aggregates
val plan = stripAQEPlan(df.queryExecution.executedPlan)
val sparkPartialAggs = plan.collect {
case agg: HashAggregateExec if agg.aggregateExpressions.exists(_.mode == Partial) => agg
}
val cometFinalAggs = plan.collect {
case agg: CometHashAggregateExec if agg.aggregateExpressions.exists(_.mode == Final) =>
agg
}

assert(
sparkPartialAggs.nonEmpty,
s"Expected Spark partial aggregates but found none. Plan: $plan")
assert(
cometFinalAggs.nonEmpty,
s"Expected Comet final aggregates but found none. Plan: $plan")
}
}

spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
}

test("sort (non-global)") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc)