Skip to content

Commit b36354e

Browse files
XorSum and Copilot authored
[AURON #1693] join operation should flush in time on duplicated keys (#1701)
<!-- Thanks for sending a pull request! Please keep the following tips in mind: - Start the PR title with the related issue ID, e.g. '[AURON #XXXX] Short summary...'. - Make your PR title clear and descriptive, summarizing what this PR changes. - Provide a concise example to reproduce the issue, if possible. - Keep the PR description up to date with all changes. --> # Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> Closes #1693. # Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> As discussed previously in #1693 and #1694, the join operation should check batch size and trigger flushing in a timely manner, to prevent extreme large batch size. # What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> # Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> # How was this patch tested? <!-- If tests were added, say they were added here. Please make sure to add some test cases that check the changes thoroughly including negative and positive cases if possible. 
If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent cf8f95f commit b36354e

2 files changed

Lines changed: 222 additions & 4 deletions

File tree

native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ impl<const L_OUTER: bool, const R_OUTER: bool> FullJoiner<L_OUTER, R_OUTER> {
5656
self.lindices.len() >= self.join_params.batch_size
5757
}
5858

59+
fn has_enough_room(&self, new_size: usize) -> bool {
60+
self.lindices.len() + new_size <= self.join_params.batch_size
61+
}
62+
5963
async fn flush(
6064
mut self: Pin<&mut Self>,
6165
cur1: &mut StreamCursor,
@@ -158,9 +162,26 @@ impl<const L_OUTER: bool, const R_OUTER: bool> Joiner for FullJoiner<L_OUTER, R_
158162
continue;
159163
}
160164

161-
for (&lidx, &ridx) in equal_lindices.iter().cartesian_product(&equal_rindices) {
162-
self.lindices.push(lidx);
163-
self.rindices.push(ridx);
165+
let new_size = equal_lindices.len() * equal_rindices.len();
166+
if self.has_enough_room(new_size) {
167+
// old cartesian_product way
168+
for (&lidx, &ridx) in
169+
equal_lindices.iter().cartesian_product(&equal_rindices)
170+
{
171+
self.lindices.push(lidx);
172+
self.rindices.push(ridx);
173+
}
174+
} else {
175+
// do more aggressive flush
176+
for &lidx in &equal_lindices {
177+
for &ridx in &equal_rindices {
178+
self.lindices.push(lidx);
179+
self.rindices.push(ridx);
180+
if self.should_flush() {
181+
self.as_mut().flush(cur1, cur2).await?;
182+
}
183+
}
184+
}
164185
}
165186

166187
if r_equal {

native-engine/datafusion-ext-plans/src/joins/test.rs

Lines changed: 198 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ mod tests {
3131
common::{JoinSide, Result},
3232
physical_expr::expressions::Column,
3333
physical_plan::{ExecutionPlan, common, joins::utils::*, test::TestMemoryExec},
34-
prelude::SessionContext,
34+
prelude::{SessionConfig, SessionContext},
3535
};
3636

3737
use crate::{
@@ -283,6 +283,91 @@ mod tests {
283283
Ok((columns, batches))
284284
}
285285

286+
async fn join_collect_with_batch_size(
287+
test_type: TestType,
288+
left: Arc<dyn ExecutionPlan>,
289+
right: Arc<dyn ExecutionPlan>,
290+
on: JoinOn,
291+
join_type: JoinType,
292+
batch_size: usize,
293+
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
294+
MemManager::init(1000000);
295+
let session_config = SessionConfig::new().with_batch_size(batch_size);
296+
let session_ctx = SessionContext::new_with_config(session_config);
297+
let task_ctx = session_ctx.task_ctx();
298+
let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?;
299+
300+
let join: Arc<dyn ExecutionPlan> = match test_type {
301+
SMJ => {
302+
let sort_options = vec![SortOptions::default(); on.len()];
303+
Arc::new(SortMergeJoinExec::try_new(
304+
schema,
305+
left,
306+
right,
307+
on,
308+
join_type,
309+
sort_options,
310+
)?)
311+
}
312+
BHJLeftProbed => {
313+
let right = Arc::new(BroadcastJoinBuildHashMapExec::new(
314+
right,
315+
on.iter().map(|(_, right_key)| right_key.clone()).collect(),
316+
));
317+
Arc::new(BroadcastJoinExec::try_new(
318+
schema,
319+
left,
320+
right,
321+
on,
322+
join_type,
323+
JoinSide::Right,
324+
true,
325+
None,
326+
)?)
327+
}
328+
BHJRightProbed => {
329+
let left = Arc::new(BroadcastJoinBuildHashMapExec::new(
330+
left,
331+
on.iter().map(|(left_key, _)| left_key.clone()).collect(),
332+
));
333+
Arc::new(BroadcastJoinExec::try_new(
334+
schema,
335+
left,
336+
right,
337+
on,
338+
join_type,
339+
JoinSide::Left,
340+
true,
341+
None,
342+
)?)
343+
}
344+
SHJLeftProbed => Arc::new(BroadcastJoinExec::try_new(
345+
schema,
346+
left,
347+
right,
348+
on,
349+
join_type,
350+
JoinSide::Right,
351+
false,
352+
None,
353+
)?),
354+
SHJRightProbed => Arc::new(BroadcastJoinExec::try_new(
355+
schema,
356+
left,
357+
right,
358+
on,
359+
join_type,
360+
JoinSide::Left,
361+
false,
362+
None,
363+
)?),
364+
};
365+
let columns = columns(&join.schema());
366+
let stream = join.execute(0, task_ctx)?;
367+
let batches = common::collect(stream).await?;
368+
Ok((columns, batches))
369+
}
370+
286371
const ALL_TEST_TYPE: [TestType; 5] = [
287372
SMJ,
288373
BHJLeftProbed,
@@ -447,6 +532,118 @@ mod tests {
447532
Ok(())
448533
}
449534

535+
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
536+
async fn join_inner_batchsize() -> Result<()> {
537+
for test_type in ALL_TEST_TYPE {
538+
let left = build_table(
539+
("a1", &vec![1, 1, 1, 1, 1]),
540+
("b1", &vec![1, 2, 3, 4, 5]),
541+
("c1", &vec![1, 2, 3, 4, 5]),
542+
);
543+
let right = build_table(
544+
("a2", &vec![1, 1, 1, 1, 1, 1, 1]),
545+
("b2", &vec![1, 2, 3, 4, 5, 6, 7]),
546+
("c2", &vec![1, 2, 3, 4, 5, 6, 7]),
547+
);
548+
let on: JoinOn = vec![(
549+
Arc::new(Column::new_with_schema("a1", &left.schema())?),
550+
Arc::new(Column::new_with_schema("a2", &right.schema())?),
551+
)];
552+
let expected = vec![
553+
"+----+----+----+----+----+----+",
554+
"| a1 | b1 | c1 | a2 | b2 | c2 |",
555+
"+----+----+----+----+----+----+",
556+
"| 1 | 1 | 1 | 1 | 1 | 1 |",
557+
"| 1 | 1 | 1 | 1 | 2 | 2 |",
558+
"| 1 | 1 | 1 | 1 | 3 | 3 |",
559+
"| 1 | 1 | 1 | 1 | 4 | 4 |",
560+
"| 1 | 1 | 1 | 1 | 5 | 5 |",
561+
"| 1 | 1 | 1 | 1 | 6 | 6 |",
562+
"| 1 | 1 | 1 | 1 | 7 | 7 |",
563+
"| 1 | 2 | 2 | 1 | 1 | 1 |",
564+
"| 1 | 2 | 2 | 1 | 2 | 2 |",
565+
"| 1 | 2 | 2 | 1 | 3 | 3 |",
566+
"| 1 | 2 | 2 | 1 | 4 | 4 |",
567+
"| 1 | 2 | 2 | 1 | 5 | 5 |",
568+
"| 1 | 2 | 2 | 1 | 6 | 6 |",
569+
"| 1 | 2 | 2 | 1 | 7 | 7 |",
570+
"| 1 | 3 | 3 | 1 | 1 | 1 |",
571+
"| 1 | 3 | 3 | 1 | 2 | 2 |",
572+
"| 1 | 3 | 3 | 1 | 3 | 3 |",
573+
"| 1 | 3 | 3 | 1 | 4 | 4 |",
574+
"| 1 | 3 | 3 | 1 | 5 | 5 |",
575+
"| 1 | 3 | 3 | 1 | 6 | 6 |",
576+
"| 1 | 3 | 3 | 1 | 7 | 7 |",
577+
"| 1 | 4 | 4 | 1 | 1 | 1 |",
578+
"| 1 | 4 | 4 | 1 | 2 | 2 |",
579+
"| 1 | 4 | 4 | 1 | 3 | 3 |",
580+
"| 1 | 4 | 4 | 1 | 4 | 4 |",
581+
"| 1 | 4 | 4 | 1 | 5 | 5 |",
582+
"| 1 | 4 | 4 | 1 | 6 | 6 |",
583+
"| 1 | 4 | 4 | 1 | 7 | 7 |",
584+
"| 1 | 5 | 5 | 1 | 1 | 1 |",
585+
"| 1 | 5 | 5 | 1 | 2 | 2 |",
586+
"| 1 | 5 | 5 | 1 | 3 | 3 |",
587+
"| 1 | 5 | 5 | 1 | 4 | 4 |",
588+
"| 1 | 5 | 5 | 1 | 5 | 5 |",
589+
"| 1 | 5 | 5 | 1 | 6 | 6 |",
590+
"| 1 | 5 | 5 | 1 | 7 | 7 |",
591+
"+----+----+----+----+----+----+",
592+
];
593+
let (_, batches) = join_collect_with_batch_size(
594+
test_type,
595+
left.clone(),
596+
right.clone(),
597+
on.clone(),
598+
Inner,
599+
2,
600+
)
601+
.await?;
602+
assert_batches_sorted_eq!(expected, &batches);
603+
let (_, batches) = join_collect_with_batch_size(
604+
test_type,
605+
left.clone(),
606+
right.clone(),
607+
on.clone(),
608+
Inner,
609+
3,
610+
)
611+
.await?;
612+
assert_batches_sorted_eq!(expected, &batches);
613+
let (_, batches) = join_collect_with_batch_size(
614+
test_type,
615+
left.clone(),
616+
right.clone(),
617+
on.clone(),
618+
Inner,
619+
4,
620+
)
621+
.await?;
622+
assert_batches_sorted_eq!(expected, &batches);
623+
let (_, batches) = join_collect_with_batch_size(
624+
test_type,
625+
left.clone(),
626+
right.clone(),
627+
on.clone(),
628+
Inner,
629+
5,
630+
)
631+
.await?;
632+
assert_batches_sorted_eq!(expected, &batches);
633+
let (_, batches) = join_collect_with_batch_size(
634+
test_type,
635+
left.clone(),
636+
right.clone(),
637+
on.clone(),
638+
Inner,
639+
7,
640+
)
641+
.await?;
642+
assert_batches_sorted_eq!(expected, &batches);
643+
}
644+
Ok(())
645+
}
646+
450647
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
451648
async fn join_left_one() -> Result<()> {
452649
for test_type in ALL_TEST_TYPE {

0 commit comments

Comments (0)