Skip to content

Commit efc82a1

Browse files
committed
case_when with is_null -> fill_null
Signed-off-by: Onur Satici <onur@spiraldb.com>
1 parent 4e6e9ed commit efc82a1

1 file changed

Lines changed: 179 additions & 0 deletions

File tree

vortex-array/src/scalar_fn/fns/case_when.rs

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ use crate::scalar_fn::ChildName;
4141
use crate::scalar_fn::ExecutionArgs;
4242
use crate::scalar_fn::ScalarFnId;
4343
use crate::scalar_fn::ScalarFnVTable;
44+
use crate::scalar_fn::SimplifyCtx;
45+
use crate::scalar_fn::fns::is_not_null::IsNotNull;
46+
use crate::scalar_fn::fns::is_null::IsNull;
47+
use crate::scalar_fn::fns::literal::Literal;
4448
use crate::scalar_fn::fns::zip::zip_impl;
4549

4650
/// Options for the n-ary CaseWhen expression.
@@ -251,6 +255,49 @@ impl ScalarFnVTable for CaseWhen {
251255
merge_case_branches(branches, else_value, ctx)
252256
}
253257

258+
fn simplify(
259+
&self,
260+
options: &Self::Options,
261+
expr: &Expression,
262+
_ctx: &dyn SimplifyCtx,
263+
) -> VortexResult<Option<Expression>> {
264+
// Rewrite the COALESCE-shaped CASE WHEN into `fill_null`, which references `x`
265+
// once and lowers to a single fill kernel instead of a `zip`/merge that resolves
266+
// `x` twice (once for the `is_null` predicate, once for the value branch).
267+
//
268+
// CASE WHEN is_null(x) THEN c ELSE x END ==> fill_null(x, c)
269+
// CASE WHEN is_not_null(x) THEN x ELSE c END ==> fill_null(x, c)
270+
//
271+
// `fill_null` requires `c` to be a non-null constant: its kernel reads the fill
272+
// value via `as_constant()` and bails on a null scalar. Restricting `c` to a
273+
// non-null `Literal` keeps the rewrite both executable and semantically exact
274+
// (replacing nulls in `x` with `c` matches the CASE only when `c` is never null).
275+
if options.num_when_then_pairs != 1 || !options.has_else {
276+
return Ok(None);
277+
}
278+
279+
let when = expr.child(0);
280+
let then = expr.child(1);
281+
let els = expr.child(2);
282+
283+
// `is_null(x) ? c : x` — predicate operand and ELSE are the same `x`, fill is THEN.
284+
let (x, fill) = if when.is::<IsNull>() && when.child(0) == els {
285+
(els, then)
286+
// `is_not_null(x) ? x : c` — predicate operand and THEN are the same `x`, fill is ELSE.
287+
} else if when.is::<IsNotNull>() && when.child(0) == then {
288+
(then, els)
289+
} else {
290+
return Ok(None);
291+
};
292+
293+
match fill.as_opt::<Literal>() {
294+
Some(scalar) if !scalar.is_null() => {
295+
Ok(Some(crate::expr::fill_null(x.clone(), fill.clone())))
296+
}
297+
_ => Ok(None),
298+
}
299+
}
300+
254301
/// Always returns `true`, regardless of `_options`.
// NOTE(review): presumably this tells the optimizer that CASE WHEN's result depends
// on the nullness of its inputs — confirm against the `ScalarFnVTable` trait docs.
fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
255302
true
256303
}
@@ -410,12 +457,15 @@ mod tests {
410457
use crate::dtype::DType;
411458
use crate::dtype::Nullability;
412459
use crate::dtype::PType;
460+
use crate::dtype::StructFields;
413461
use crate::expr::case_when;
414462
use crate::expr::case_when_no_else;
415463
use crate::expr::col;
416464
use crate::expr::eq;
417465
use crate::expr::get_item;
418466
use crate::expr::gt;
467+
use crate::expr::is_not_null;
468+
use crate::expr::is_null;
419469
use crate::expr::lit;
420470
use crate::expr::nested_case_when;
421471
use crate::expr::root;
@@ -1193,6 +1243,135 @@ mod tests {
11931243
assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
11941244
}
11951245

1246+
// ==================== Simplify: COALESCE -> fill_null ====================
1247+
1248+
/// Builds a non-nullable struct scope whose named fields are all `Nullable(I64)`.
1249+
fn nullable_i64_scope(fields: &[&str]) -> DType {
1250+
DType::Struct(
1251+
StructFields::new(
1252+
fields.to_vec().into(),
1253+
vec![DType::Primitive(PType::I64, Nullability::Nullable); fields.len()],
1254+
),
1255+
Nullability::NonNullable,
1256+
)
1257+
}
1258+
1259+
/// `CASE WHEN is_null(x) THEN 0 ELSE x END` must be rewritten to `fill_null(x, 0)`.
#[test]
fn test_simplify_coalesce_is_null_rewrites_to_fill_null() -> VortexResult<()> {
    let coalesce = case_when(is_null(col("x")), lit(0i64), col("x"));
    let scope = nullable_i64_scope(&["x"]);

    let optimized = coalesce.optimize_recursive(&scope)?;

    let became_fill_null = optimized.to_string().starts_with("vortex.fill_null");
    assert!(became_fill_null, "expected fill_null, got {optimized}");
    Ok(())
}
1270+
1271+
/// `CASE WHEN is_not_null(x) THEN x ELSE 0 END` must be rewritten to `fill_null(x, 0)`.
#[test]
fn test_simplify_coalesce_is_not_null_rewrites_to_fill_null() -> VortexResult<()> {
    let coalesce = case_when(is_not_null(col("x")), col("x"), lit(0i64));
    let scope = nullable_i64_scope(&["x"]);

    let optimized = coalesce.optimize_recursive(&scope)?;

    let became_fill_null = optimized.to_string().starts_with("vortex.fill_null");
    assert!(became_fill_null, "expected fill_null, got {optimized}");
    Ok(())
}
1282+
1283+
/// When the `is_null` operand (`x`) and the ELSE branch (`y`) are different columns
/// the expression is not a COALESCE and must be left as a CASE WHEN.
#[test]
fn test_simplify_does_not_fire_when_operands_differ() -> VortexResult<()> {
    let not_a_coalesce = case_when(is_null(col("x")), lit(0i64), col("y"));
    let scope = nullable_i64_scope(&["x", "y"]);

    let s = not_a_coalesce.optimize_recursive(&scope)?.to_string();

    assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
    assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
    Ok(())
}
1293+
1294+
/// COALESCE(x, c) where the fill is a *column*: `fill_null` cannot consume a
/// non-constant fill value, so the CASE WHEN must be left untouched.
#[test]
fn test_simplify_does_not_fire_for_non_constant_fill() -> VortexResult<()> {
    let column_fill = case_when(is_null(col("x")), col("c"), col("x"));
    let scope = nullable_i64_scope(&["x", "c"]);

    let s = column_fill.optimize_recursive(&scope)?.to_string();

    assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
    assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
    Ok(())
}
1305+
1306+
/// A null fill literal would make `fill_null` bail and is not semantically a
/// COALESCE, so the rewrite must not fire.
#[test]
fn test_simplify_does_not_fire_for_null_fill() -> VortexResult<()> {
    let null_i64 = Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable));
    let null_filled = case_when(is_null(col("x")), lit(null_i64), col("x"));
    let scope = nullable_i64_scope(&["x"]);

    let s = null_filled.optimize_recursive(&scope)?.to_string();

    assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
    assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
    Ok(())
}
1320+
1321+
/// A CASE WHEN with no ELSE branch must never be rewritten to `fill_null`.
#[test]
fn test_simplify_does_not_fire_without_else() -> VortexResult<()> {
    let no_else = case_when_no_else(is_null(col("x")), lit(0i64));
    let scope = nullable_i64_scope(&["x"]);

    let optimized = no_else.optimize_recursive(&scope)?;

    let mentions_fill_null = optimized.to_string().contains("fill_null");
    assert!(
        !mentions_fill_null,
        "must not rewrite a no-ELSE case_when, got {optimized}"
    );
    Ok(())
}
1331+
1332+
/// A CASE WHEN with more than one WHEN/THEN pair must never be rewritten to
/// `fill_null`, even if its first pair and ELSE look COALESCE-shaped.
#[test]
fn test_simplify_does_not_fire_for_multi_pair() -> VortexResult<()> {
    let null_branch = (is_null(col("x")), lit(0i64));
    let gt_branch = (gt(col("x"), lit(5i64)), lit(1i64));
    let multi_pair = nested_case_when(vec![null_branch, gt_branch], Some(col("x")));
    let scope = nullable_i64_scope(&["x"]);

    let optimized = multi_pair.optimize_recursive(&scope)?;

    let mentions_fill_null = optimized.to_string().contains("fill_null");
    assert!(
        !mentions_fill_null,
        "must not rewrite a multi-pair case_when, got {optimized}"
    );
    Ok(())
}
1348+
1349+
/// The optimized expression must produce the same values as the original CASE WHEN.
#[test]
fn test_simplify_semantic_equivalence() -> VortexResult<()> {
    let input = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
    let input_dtype = DType::Primitive(PType::I64, Nullability::Nullable);

    let original = case_when(is_null(root()), lit(0i64), root());
    let optimized = original.optimize_recursive(&input_dtype)?;
    let became_fill_null = optimized.to_string().starts_with("vortex.fill_null");
    assert!(became_fill_null, "expected fill_null, got {optimized}");

    // The two results differ only in nullability: the original keeps CASE WHEN's
    // nullable result dtype, while the rewrite is NonNullable because a non-null
    // fill cannot leave any nulls behind. The values themselves are identical.
    let expected_nullable =
        PrimitiveArray::from_option_iter([Some(1i64), Some(0), Some(3)]).into_array();
    assert_arrays_eq!(evaluate_expr(&original, &input), expected_nullable);

    let expected_non_nullable = buffer![1i64, 0, 3].into_array();
    assert_arrays_eq!(evaluate_expr(&optimized, &input), expected_non_nullable);
    Ok(())
}
1374+
11961375
#[test]
11971376
fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
11981377
// Exercises the scalar path: alternating rows produce one slice per row (no runs),

0 commit comments

Comments
 (0)