diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 758317d3d2798..6e69c7c079c9a 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -42,6 +42,7 @@ use crate::expressions::case::literal_lookup_table::LiteralLookupTable; use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_physical_expr_common::datum::compare_with_eq; +use datafusion_physical_expr_common::utils::scatter; use itertools::Itertools; use std::fmt::{Debug, Formatter}; @@ -659,7 +660,7 @@ impl CaseExpr { && body.else_expr.as_ref().unwrap().as_any().is::() { EvalMethod::ScalarOrScalar - } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() { + } else if body.when_then_expr.len() == 1 { EvalMethod::ExpressionOrExpression(body.project()?) } else { EvalMethod::NoExpression(body.project()?) @@ -961,32 +962,40 @@ impl CaseBody { let then_batch = filter_record_batch(batch, &when_filter)?; let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?; - let else_selection = not(&when_value)?; - let else_filter = create_filter(&else_selection, optimize_filter); - let else_batch = filter_record_batch(batch, &else_filter)?; - - // keep `else_expr`'s data type and return type consistent - let e = self.else_expr.as_ref().unwrap(); - let return_type = self.data_type(&batch.schema())?; - let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); - - let else_value = else_expr.evaluate(&else_batch)?; - - Ok(ColumnarValue::Array(match (then_value, else_value) { - (ColumnarValue::Array(t), ColumnarValue::Array(e)) => { - merge(&when_value, &t, &e) - } - (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => { - merge(&when_value, &t.to_scalar()?, &e) - } - (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => { - merge(&when_value, &t, &e.to_scalar()?) + match &self.else_expr { + None => { + let then_array = then_value.to_array(when_value.true_count())?; + scatter(&when_value, then_array.as_ref()).map(ColumnarValue::Array) } - (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => { - merge(&when_value, &t.to_scalar()?, &e.to_scalar()?) + Some(else_expr) => { + let else_selection = not(&when_value)?; + let else_filter = create_filter(&else_selection, optimize_filter); + let else_batch = filter_record_batch(batch, &else_filter)?; + + // keep `else_expr`'s data type and return type consistent + let return_type = self.data_type(&batch.schema())?; + let else_expr = + try_cast(Arc::clone(else_expr), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(else_expr)); + + let else_value = else_expr.evaluate(&else_batch)?; + + Ok(ColumnarValue::Array(match (then_value, else_value) { + (ColumnarValue::Array(t), ColumnarValue::Array(e)) => { + merge(&when_value, &t, &e) + } + (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => { + merge(&when_value, &t.to_scalar()?, &e) + } + (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => { + merge(&when_value, &t, &e.to_scalar()?) + } + (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => { + merge(&when_value, &t.to_scalar()?, &e.to_scalar()?) + } + }?)) } - }?)) + } } } @@ -1137,7 +1146,15 @@ impl CaseExpr { self.body.when_then_expr[0].1.evaluate(batch) } else if true_count == 0 { // All input rows are false/null, just call the 'else' expression - self.body.else_expr.as_ref().unwrap().evaluate(batch) + match &self.body.else_expr { + Some(else_expr) => else_expr.evaluate(batch), + None => { + let return_type = self.data_type(&batch.schema())?; + Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + &return_type, + )?)) + } + } } else if projected.projection.len() < batch.num_columns() { // The case expressions do not use all the columns of the input batch. // Project first to reduce time spent filtering.