@@ -41,6 +41,10 @@ use crate::scalar_fn::ChildName;
4141use crate :: scalar_fn:: ExecutionArgs ;
4242use crate :: scalar_fn:: ScalarFnId ;
4343use crate :: scalar_fn:: ScalarFnVTable ;
44+ use crate :: scalar_fn:: SimplifyCtx ;
45+ use crate :: scalar_fn:: fns:: is_not_null:: IsNotNull ;
46+ use crate :: scalar_fn:: fns:: is_null:: IsNull ;
47+ use crate :: scalar_fn:: fns:: literal:: Literal ;
4448use crate :: scalar_fn:: fns:: zip:: zip_impl;
4549
4650/// Options for the n-ary CaseWhen expression.
@@ -251,6 +255,49 @@ impl ScalarFnVTable for CaseWhen {
251255 merge_case_branches ( branches, else_value, ctx)
252256 }
253257
258+ fn simplify (
259+ & self ,
260+ options : & Self :: Options ,
261+ expr : & Expression ,
262+ _ctx : & dyn SimplifyCtx ,
263+ ) -> VortexResult < Option < Expression > > {
264+ // Rewrite the COALESCE-shaped CASE WHEN into `fill_null`, which references `x`
265+ // once and lowers to a single fill kernel instead of a `zip`/merge that resolves
266+ // `x` twice (once for the `is_null` predicate, once for the value branch).
267+ //
268+ // CASE WHEN is_null(x) THEN c ELSE x END ==> fill_null(x, c)
269+ // CASE WHEN is_not_null(x) THEN x ELSE c END ==> fill_null(x, c)
270+ //
271+ // `fill_null` requires `c` to be a non-null constant: its kernel reads the fill
272+ // value via `as_constant()` and bails on a null scalar. Restricting `c` to a
273+ // non-null `Literal` keeps the rewrite both executable and semantically exact
274+ // (replacing nulls in `x` with `c` matches the CASE only when `c` is never null).
275+ if options. num_when_then_pairs != 1 || !options. has_else {
276+ return Ok ( None ) ;
277+ }
278+
279+ let when = expr. child ( 0 ) ;
280+ let then = expr. child ( 1 ) ;
281+ let els = expr. child ( 2 ) ;
282+
283+ // `is_null(x) ? c : x` — predicate operand and ELSE are the same `x`, fill is THEN.
284+ let ( x, fill) = if when. is :: < IsNull > ( ) && when. child ( 0 ) == els {
285+ ( els, then)
286+ // `is_not_null(x) ? x : c` — predicate operand and THEN are the same `x`, fill is ELSE.
287+ } else if when. is :: < IsNotNull > ( ) && when. child ( 0 ) == then {
288+ ( then, els)
289+ } else {
290+ return Ok ( None ) ;
291+ } ;
292+
293+ match fill. as_opt :: < Literal > ( ) {
294+ Some ( scalar) if !scalar. is_null ( ) => {
295+ Ok ( Some ( crate :: expr:: fill_null ( x. clone ( ) , fill. clone ( ) ) ) )
296+ }
297+ _ => Ok ( None ) ,
298+ }
299+ }
300+
254301 fn is_null_sensitive ( & self , _options : & Self :: Options ) -> bool {
255302 true
256303 }
@@ -410,12 +457,15 @@ mod tests {
410457 use crate :: dtype:: DType ;
411458 use crate :: dtype:: Nullability ;
412459 use crate :: dtype:: PType ;
460+ use crate :: dtype:: StructFields ;
413461 use crate :: expr:: case_when;
414462 use crate :: expr:: case_when_no_else;
415463 use crate :: expr:: col;
416464 use crate :: expr:: eq;
417465 use crate :: expr:: get_item;
418466 use crate :: expr:: gt;
467+ use crate :: expr:: is_not_null;
468+ use crate :: expr:: is_null;
419469 use crate :: expr:: lit;
420470 use crate :: expr:: nested_case_when;
421471 use crate :: expr:: root;
@@ -1193,6 +1243,135 @@ mod tests {
11931243 assert_arrays_eq ! ( result, buffer![ 10i32 , 20 , 0 ] . into_array( ) ) ;
11941244 }
11951245
1246+ // ==================== Simplify: COALESCE -> fill_null ====================
1247+
1248+ /// Builds a non-nullable struct scope whose named fields are all `Nullable(I64)`.
1249+ fn nullable_i64_scope ( fields : & [ & str ] ) -> DType {
1250+ DType :: Struct (
1251+ StructFields :: new (
1252+ fields. to_vec ( ) . into ( ) ,
1253+ vec ! [ DType :: Primitive ( PType :: I64 , Nullability :: Nullable ) ; fields. len( ) ] ,
1254+ ) ,
1255+ Nullability :: NonNullable ,
1256+ )
1257+ }
1258+
1259+ #[ test]
1260+ fn test_simplify_coalesce_is_null_rewrites_to_fill_null ( ) -> VortexResult < ( ) > {
1261+ // CASE WHEN is_null(x) THEN 0 ELSE x END ==> fill_null(x, 0)
1262+ let expr = case_when ( is_null ( col ( "x" ) ) , lit ( 0i64 ) , col ( "x" ) ) ;
1263+ let optimized = expr. optimize_recursive ( & nullable_i64_scope ( & [ "x" ] ) ) ?;
1264+ assert ! (
1265+ optimized. to_string( ) . starts_with( "vortex.fill_null" ) ,
1266+ "expected fill_null, got {optimized}"
1267+ ) ;
1268+ Ok ( ( ) )
1269+ }
1270+
1271+ #[ test]
1272+ fn test_simplify_coalesce_is_not_null_rewrites_to_fill_null ( ) -> VortexResult < ( ) > {
1273+ // CASE WHEN is_not_null(x) THEN x ELSE 0 END ==> fill_null(x, 0)
1274+ let expr = case_when ( is_not_null ( col ( "x" ) ) , col ( "x" ) , lit ( 0i64 ) ) ;
1275+ let optimized = expr. optimize_recursive ( & nullable_i64_scope ( & [ "x" ] ) ) ?;
1276+ assert ! (
1277+ optimized. to_string( ) . starts_with( "vortex.fill_null" ) ,
1278+ "expected fill_null, got {optimized}"
1279+ ) ;
1280+ Ok ( ( ) )
1281+ }
1282+
1283+ #[ test]
1284+ fn test_simplify_does_not_fire_when_operands_differ ( ) -> VortexResult < ( ) > {
1285+ // The is_null operand (x) and the ELSE (y) are different columns: not a COALESCE.
1286+ let expr = case_when ( is_null ( col ( "x" ) ) , lit ( 0i64 ) , col ( "y" ) ) ;
1287+ let optimized = expr. optimize_recursive ( & nullable_i64_scope ( & [ "x" , "y" ] ) ) ?;
1288+ let s = optimized. to_string ( ) ;
1289+ assert ! ( s. contains( "CASE" ) , "expected CASE WHEN to remain, got {s}" ) ;
1290+ assert ! ( !s. contains( "fill_null" ) , "must not rewrite, got {s}" ) ;
1291+ Ok ( ( ) )
1292+ }
1293+
1294+ #[ test]
1295+ fn test_simplify_does_not_fire_for_non_constant_fill ( ) -> VortexResult < ( ) > {
1296+ // COALESCE(x, c) with a *column* fill: fill_null cannot consume a non-constant
1297+ // fill value, so the rewrite must not fire.
1298+ let expr = case_when ( is_null ( col ( "x" ) ) , col ( "c" ) , col ( "x" ) ) ;
1299+ let optimized = expr. optimize_recursive ( & nullable_i64_scope ( & [ "x" , "c" ] ) ) ?;
1300+ let s = optimized. to_string ( ) ;
1301+ assert ! ( s. contains( "CASE" ) , "expected CASE WHEN to remain, got {s}" ) ;
1302+ assert ! ( !s. contains( "fill_null" ) , "must not rewrite, got {s}" ) ;
1303+ Ok ( ( ) )
1304+ }
1305+
1306+ #[ test]
1307+ fn test_simplify_does_not_fire_for_null_fill ( ) -> VortexResult < ( ) > {
1308+ // A null fill literal would make fill_null bail and is not semantically a COALESCE.
1309+ let null_fill = lit ( Scalar :: null ( DType :: Primitive (
1310+ PType :: I64 ,
1311+ Nullability :: Nullable ,
1312+ ) ) ) ;
1313+ let expr = case_when ( is_null ( col ( "x" ) ) , null_fill, col ( "x" ) ) ;
1314+ let optimized = expr. optimize_recursive ( & nullable_i64_scope ( & [ "x" ] ) ) ?;
1315+ let s = optimized. to_string ( ) ;
1316+ assert ! ( s. contains( "CASE" ) , "expected CASE WHEN to remain, got {s}" ) ;
1317+ assert ! ( !s. contains( "fill_null" ) , "must not rewrite, got {s}" ) ;
1318+ Ok ( ( ) )
1319+ }
1320+
1321+ #[ test]
1322+ fn test_simplify_does_not_fire_without_else ( ) -> VortexResult < ( ) > {
1323+ let expr = case_when_no_else ( is_null ( col ( "x" ) ) , lit ( 0i64 ) ) ;
1324+ let optimized = expr. optimize_recursive ( & nullable_i64_scope ( & [ "x" ] ) ) ?;
1325+ assert ! (
1326+ !optimized. to_string( ) . contains( "fill_null" ) ,
1327+ "must not rewrite a no-ELSE case_when, got {optimized}"
1328+ ) ;
1329+ Ok ( ( ) )
1330+ }
1331+
1332+ #[ test]
1333+ fn test_simplify_does_not_fire_for_multi_pair ( ) -> VortexResult < ( ) > {
1334+ let expr = nested_case_when (
1335+ vec ! [
1336+ ( is_null( col( "x" ) ) , lit( 0i64 ) ) ,
1337+ ( gt( col( "x" ) , lit( 5i64 ) ) , lit( 1i64 ) ) ,
1338+ ] ,
1339+ Some ( col ( "x" ) ) ,
1340+ ) ;
1341+ let optimized = expr. optimize_recursive ( & nullable_i64_scope ( & [ "x" ] ) ) ?;
1342+ assert ! (
1343+ !optimized. to_string( ) . contains( "fill_null" ) ,
1344+ "must not rewrite a multi-pair case_when, got {optimized}"
1345+ ) ;
1346+ Ok ( ( ) )
1347+ }
1348+
1349+ #[ test]
1350+ fn test_simplify_semantic_equivalence ( ) -> VortexResult < ( ) > {
1351+ // The optimized expression must produce the same values as the original CASE WHEN.
1352+ let array = PrimitiveArray :: from_option_iter ( [ Some ( 1i64 ) , None , Some ( 3 ) ] ) . into_array ( ) ;
1353+ let scope = DType :: Primitive ( PType :: I64 , Nullability :: Nullable ) ;
1354+
1355+ let original = case_when ( is_null ( root ( ) ) , lit ( 0i64 ) , root ( ) ) ;
1356+ let optimized = original. optimize_recursive ( & scope) ?;
1357+ assert ! (
1358+ optimized. to_string( ) . starts_with( "vortex.fill_null" ) ,
1359+ "expected fill_null, got {optimized}"
1360+ ) ;
1361+
1362+ // Original keeps CASE WHEN's nullable result dtype; the rewrite tightens it to
1363+ // NonNullable because a non-null fill cannot leave any nulls behind. Values match.
1364+ assert_arrays_eq ! (
1365+ evaluate_expr( & original, & array) ,
1366+ PrimitiveArray :: from_option_iter( [ Some ( 1i64 ) , Some ( 0 ) , Some ( 3 ) ] ) . into_array( )
1367+ ) ;
1368+ assert_arrays_eq ! (
1369+ evaluate_expr( & optimized, & array) ,
1370+ buffer![ 1i64 , 0 , 3 ] . into_array( )
1371+ ) ;
1372+ Ok ( ( ) )
1373+ }
1374+
11961375 #[ test]
11971376 fn test_merge_case_branches_alternating_mask ( ) -> VortexResult < ( ) > {
11981377 // Exercises the scalar path: alternating rows produce one slice per row (no runs),
0 commit comments