@@ -27,9 +27,9 @@ use datafusion::execution::memory_pool::{GreedyMemoryPool, MemoryPool, TrackCons
2727use datafusion:: execution:: runtime_env:: RuntimeEnvBuilder ;
2828use sedona:: context:: SedonaContext ;
2929use sedona:: memory_pool:: { SedonaFairSpillPool , DEFAULT_UNSPILLABLE_RESERVE_RATIO } ;
30+ use sedona:: pool_type:: PoolType ;
3031use sedona_cli:: {
3132 exec,
32- pool_type:: PoolType ,
3333 print_format:: PrintFormat ,
3434 print_options:: { MaxRows , PrintOptions } ,
3535 DATAFUSION_CLI_VERSION ,
@@ -275,25 +275,25 @@ fn parse_size_string(size: &str, label: &str) -> Result<usize, String> {
275275 } ) ;
276276
277277 static SUFFIX_REGEX : LazyLock < regex:: Regex > =
278- LazyLock :: new ( || regex:: Regex :: new ( r"^([0-9]+)([a-z]+)?$" ) . unwrap ( ) ) ;
278+ LazyLock :: new ( || regex:: Regex :: new ( r"^([0-9. ]+)\s* ([a-z]+)?$" ) . unwrap ( ) ) ;
279279
280280 let lower = size. to_lowercase ( ) ;
281281 if let Some ( caps) = SUFFIX_REGEX . captures ( & lower) {
282282 let num_str = caps. get ( 1 ) . unwrap ( ) . as_str ( ) ;
283283 let num = num_str
284- . parse :: < usize > ( )
284+ . parse :: < f64 > ( )
285285 . map_err ( |_| format ! ( "Invalid numeric value in {label} '{size}'" ) ) ?;
286286
287287 let suffix = caps. get ( 2 ) . map ( |m| m. as_str ( ) ) . unwrap_or ( "b" ) ;
288288 let unit = BYTE_SUFFIXES
289289 . get ( suffix)
290290 . ok_or_else ( || format ! ( "Invalid {label} '{size}'" ) ) ?;
291- let total_bytes = usize :: try_from ( unit. multiplier ( ) )
292- . ok ( )
293- . and_then ( |multiplier| num . checked_mul ( multiplier ) )
294- . ok_or_else ( || format ! ( "{label} '{size}' is too large" ) ) ? ;
291+ let total_bytes = num * unit. multiplier ( ) as f64 ;
292+ if !total_bytes . is_finite ( ) || total_bytes > usize :: MAX as f64 {
293+ return Err ( format ! ( "{label} '{size}' is too large" ) ) ;
294+ }
295295
296- Ok ( total_bytes)
296+ Ok ( total_bytes as usize )
297297 } else {
298298 Err ( format ! ( "Invalid {label} '{size}'" ) )
299299 }
@@ -314,3 +314,73 @@ fn validate_unspillable_reserve_ratio(s: &str) -> Result<f64, String> {
314314 }
315315 Ok ( value)
316316}
317+
318+ #[ cfg( test) ]
319+ mod tests {
320+ use super :: * ;
321+
322+ fn assert_conversion ( input : & str , expected : Result < usize , String > ) {
323+ let result = extract_memory_pool_size ( input) ;
324+ match expected {
325+ Ok ( v) => assert_eq ! ( result. unwrap( ) , v) ,
326+ Err ( e) => assert_eq ! ( result. unwrap_err( ) , e) ,
327+ }
328+ }
329+
330+ #[ test]
331+ fn memory_pool_size ( ) -> Result < ( ) , String > {
332+ // Test basic sizes without suffix, assumed to be bytes
333+ assert_conversion ( "5" , Ok ( 5 ) ) ;
334+ assert_conversion ( "100" , Ok ( 100 ) ) ;
335+
336+ // Test various units
337+ assert_conversion ( "5b" , Ok ( 5 ) ) ;
338+ assert_conversion ( "4k" , Ok ( 4 * 1024 ) ) ;
339+ assert_conversion ( "4kb" , Ok ( 4 * 1024 ) ) ;
340+ assert_conversion ( "20m" , Ok ( 20 * 1024 * 1024 ) ) ;
341+ assert_conversion ( "20mb" , Ok ( 20 * 1024 * 1024 ) ) ;
342+ assert_conversion ( "2g" , Ok ( 2 * 1024 * 1024 * 1024 ) ) ;
343+ assert_conversion ( "2gb" , Ok ( 2 * 1024 * 1024 * 1024 ) ) ;
344+ assert_conversion ( "3t" , Ok ( 3 * 1024 * 1024 * 1024 * 1024 ) ) ;
345+ assert_conversion ( "4tb" , Ok ( 4 * 1024 * 1024 * 1024 * 1024 ) ) ;
346+
347+ // Test case insensitivity
348+ assert_conversion ( "4K" , Ok ( 4 * 1024 ) ) ;
349+ assert_conversion ( "4KB" , Ok ( 4 * 1024 ) ) ;
350+ assert_conversion ( "20M" , Ok ( 20 * 1024 * 1024 ) ) ;
351+ assert_conversion ( "20MB" , Ok ( 20 * 1024 * 1024 ) ) ;
352+ assert_conversion ( "2G" , Ok ( 2 * 1024 * 1024 * 1024 ) ) ;
353+ assert_conversion ( "2GB" , Ok ( 2 * 1024 * 1024 * 1024 ) ) ;
354+ assert_conversion ( "2T" , Ok ( 2 * 1024 * 1024 * 1024 * 1024 ) ) ;
355+
356+ // Test decimal values
357+ assert_conversion ( "1.5g" , Ok ( ( 1.5 * 1024.0 * 1024.0 * 1024.0 ) as usize ) ) ;
358+ assert_conversion ( "0.5m" , Ok ( ( 0.5 * 1024.0 * 1024.0 ) as usize ) ) ;
359+ assert_conversion ( "9.5 gb" , Ok ( ( 9.5 * 1024.0 * 1024.0 * 1024.0 ) as usize ) ) ;
360+
361+ // Test with spaces between number and suffix
362+ assert_conversion ( "4 k" , Ok ( 4 * 1024 ) ) ;
363+ assert_conversion ( "20 mb" , Ok ( 20 * 1024 * 1024 ) ) ;
364+
365+ // Test invalid input
366+ assert_conversion (
367+ "invalid" ,
368+ Err ( "Invalid memory pool size 'invalid'" . to_string ( ) ) ,
369+ ) ;
370+ assert_conversion ( "4kbx" , Err ( "Invalid memory pool size '4kbx'" . to_string ( ) ) ) ;
371+ assert_conversion ( "-20mb" , Err ( "Invalid memory pool size '-20mb'" . to_string ( ) ) ) ;
372+ assert_conversion ( "-100" , Err ( "Invalid memory pool size '-100'" . to_string ( ) ) ) ;
373+ assert_conversion (
374+ "12k12k" ,
375+ Err ( "Invalid memory pool size '12k12k'" . to_string ( ) ) ,
376+ ) ;
377+
378+ // Test overflow
379+ assert_conversion (
380+ "99999999t" ,
381+ Err ( "memory pool size '99999999t' is too large" . to_string ( ) ) ,
382+ ) ;
383+
384+ Ok ( ( ) )
385+ }
386+ }
0 commit comments