@@ -142,27 +142,13 @@ mod tests {
142142 expected_core : & [ Option < i32 > ] ,
143143 expected_shredded : & [ Option < i32 > ] ,
144144 ) -> VortexResult < ( ) > {
145- assert_eq ! ( array. len ( ) , expected_core. len ( ) ) ;
145+ assert_variant_core_rows ( array, expected_core) ? ;
146146 assert_eq ! ( array. len( ) , expected_shredded. len( ) ) ;
147147
148- let mut ctx = LEGACY_SESSION . create_execution_ctx ( ) ;
149- for ( idx, expected) in expected_core. iter ( ) . enumerate ( ) {
150- let scalar = array. core_storage ( ) . execute_scalar ( idx, & mut ctx) ?;
151- let variant = scalar. as_variant ( ) ;
152- match expected {
153- Some ( expected) => {
154- let value = variant
155- . value ( )
156- . ok_or_else ( || vortex_err ! ( "expected non-null variant row" ) ) ?;
157- assert_eq ! ( value. as_primitive( ) . typed_value:: <i32 >( ) , Some ( * expected) ) ;
158- }
159- None => assert ! ( variant. is_null( ) ) ,
160- }
161- }
162-
163148 let shredded = array
164149 . shredded ( )
165150 . ok_or_else ( || vortex_err ! ( "expected shredded child" ) ) ?;
151+ let mut ctx = LEGACY_SESSION . create_execution_ctx ( ) ;
166152 let shredded = shredded. clone ( ) . execute :: < PrimitiveArray > ( & mut ctx) ?;
167153 let expected_shredded_array = if let Some ( values) = expected_shredded
168154 . iter ( )
@@ -178,6 +164,30 @@ mod tests {
178164 Ok ( ( ) )
179165 }
180166
167+ fn assert_variant_core_rows (
168+ array : & VariantArray ,
169+ expected_core : & [ Option < i32 > ] ,
170+ ) -> VortexResult < ( ) > {
171+ assert_eq ! ( array. len( ) , expected_core. len( ) ) ;
172+
173+ let mut ctx = LEGACY_SESSION . create_execution_ctx ( ) ;
174+ for ( idx, expected) in expected_core. iter ( ) . enumerate ( ) {
175+ let scalar = array. core_storage ( ) . execute_scalar ( idx, & mut ctx) ?;
176+ let variant = scalar. as_variant ( ) ;
177+ match expected {
178+ Some ( expected) => {
179+ let value = variant
180+ . value ( )
181+ . ok_or_else ( || vortex_err ! ( "expected non-null variant row" ) ) ?;
182+ assert_eq ! ( value. as_primitive( ) . typed_value:: <i32 >( ) , Some ( * expected) ) ;
183+ }
184+ None => assert ! ( variant. is_null( ) ) ,
185+ }
186+ }
187+
188+ Ok ( ( ) )
189+ }
190+
181191 #[ test]
182192 fn try_new_exposes_core_storage_without_shredded ( ) -> VortexResult < ( ) > {
183193 let core_storage = core_storage ( 2 ) ;
@@ -331,6 +341,31 @@ mod tests {
331341 )
332342 }
333343
344+ #[ test]
345+ fn mask_preserves_chunked_core_storage_validity ( ) -> VortexResult < ( ) > {
346+ let dtype = DType :: Variant ( Nullability :: Nullable ) ;
347+ let core_chunks = [ Some ( 1i32 ) , None , Some ( 3 ) , Some ( 4 ) ]
348+ . into_iter ( )
349+ . map ( |value| {
350+ let scalar = match value {
351+ Some ( value) => {
352+ Scalar :: variant ( Scalar :: primitive ( value, Nullability :: NonNullable ) )
353+ . cast ( & dtype) ?
354+ }
355+ None => Scalar :: null ( dtype. clone ( ) ) ,
356+ } ;
357+ Ok ( ConstantArray :: new ( scalar, 1 ) . into_array ( ) )
358+ } )
359+ . collect :: < VortexResult < Vec < _ > > > ( ) ?;
360+ let core_storage = ChunkedArray :: try_new ( core_chunks, dtype) ?. into_array ( ) ;
361+ let variant = VariantArray :: try_new ( core_storage, None ) ?;
362+ let mask = BoolArray :: from_iter ( [ true , true , false , true ] ) . into_array ( ) ;
363+
364+ let masked = execute_variant ( variant. into_array ( ) . mask ( mask) ?) ?;
365+
366+ assert_variant_core_rows ( & masked, & [ Some ( 1 ) , None , None , Some ( 4 ) ] )
367+ }
368+
334369 #[ test]
335370 fn variant_get_keeps_valid_shredded_rows_for_matching_dtype ( ) -> VortexResult < ( ) > {
336371 let core_storage = row_storage ( [ 1 , 2 , 3 ] ) ?;
0 commit comments