Skip to content

Commit 754e996

Browse files
committed
fix: Address remaining review comments on memory pool PR
- Move PoolType from sedona-cli to rust/sedona for R/Python reuse - Fix parse_size_string to handle decimal values (e.g. 1.5g, 0.5m) - Add comprehensive tests for parse_size_string - Fix SedonaFairSpillPool doc comment formatting
1 parent e70b076 commit 754e996

5 files changed

Lines changed: 83 additions & 11 deletions

File tree

rust/sedona/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub mod context;
1919
mod exec;
2020
pub mod memory_pool;
2121
mod object_storage;
22+
pub mod pool_type;
2223
pub mod random_geometry_provider;
2324
pub mod reader;
2425
pub mod record_batch_reader_provider;

rust/sedona/src/memory_pool.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@ pub const DEFAULT_UNSPILLABLE_RESERVE_RATIO: f64 = 0.2;
2626
/// A [`MemoryPool`] implementation similar to DataFusion's [`datafusion::execution::memory_pool::FairSpillPool`],
2727
/// but with the following changes:
2828
///
29-
/// It implements a reservation mechanism for unspillable memory consumers. This addresses an issue
30-
/// where spillable consumers could potentially exhaust all available memory, preventing unspillable
29+
/// Spillable and non-spillable operators use logically separate portions of the memory pool,
30+
/// controlled by `unspillable_reserve_ratio`, instead of sharing a single pool as in
31+
/// DataFusion's default FairSpillPool, which can lead to the following issue:
32+
/// spillable consumers could potentially exhaust all available memory, preventing unspillable
3133
/// operations from acquiring necessary resources. This behavior is tracked in DataFusion issue
3234
/// https://github.com/apache/datafusion/issues/17334. In the context of Sedona, a typical example
3335
/// is a [`sedona_spatial_join::exec::SpatialJoinExec`] operator with an auto inserted

sedona-cli/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,5 @@ pub mod exec;
2828
pub mod functions;
2929
pub mod helper;
3030
pub mod highlighter;
31-
pub mod pool_type;
3231
pub mod print_format;
3332
pub mod print_options;

sedona-cli/src/main.rs

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ use datafusion::execution::memory_pool::{GreedyMemoryPool, MemoryPool, TrackCons
2727
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
2828
use sedona::context::SedonaContext;
2929
use sedona::memory_pool::{SedonaFairSpillPool, DEFAULT_UNSPILLABLE_RESERVE_RATIO};
30+
use sedona::pool_type::PoolType;
3031
use sedona_cli::{
3132
exec,
32-
pool_type::PoolType,
3333
print_format::PrintFormat,
3434
print_options::{MaxRows, PrintOptions},
3535
DATAFUSION_CLI_VERSION,
@@ -275,25 +275,25 @@ fn parse_size_string(size: &str, label: &str) -> Result<usize, String> {
275275
});
276276

277277
static SUFFIX_REGEX: LazyLock<regex::Regex> =
278-
LazyLock::new(|| regex::Regex::new(r"^([0-9]+)([a-z]+)?$").unwrap());
278+
LazyLock::new(|| regex::Regex::new(r"^([0-9.]+)\s*([a-z]+)?$").unwrap());
279279

280280
let lower = size.to_lowercase();
281281
if let Some(caps) = SUFFIX_REGEX.captures(&lower) {
282282
let num_str = caps.get(1).unwrap().as_str();
283283
let num = num_str
284-
.parse::<usize>()
284+
.parse::<f64>()
285285
.map_err(|_| format!("Invalid numeric value in {label} '{size}'"))?;
286286

287287
let suffix = caps.get(2).map(|m| m.as_str()).unwrap_or("b");
288288
let unit = BYTE_SUFFIXES
289289
.get(suffix)
290290
.ok_or_else(|| format!("Invalid {label} '{size}'"))?;
291-
let total_bytes = usize::try_from(unit.multiplier())
292-
.ok()
293-
.and_then(|multiplier| num.checked_mul(multiplier))
294-
.ok_or_else(|| format!("{label} '{size}' is too large"))?;
291+
let total_bytes = num * unit.multiplier() as f64;
292+
if !total_bytes.is_finite() || total_bytes > usize::MAX as f64 {
293+
return Err(format!("{label} '{size}' is too large"));
294+
}
295295

296-
Ok(total_bytes)
296+
Ok(total_bytes as usize)
297297
} else {
298298
Err(format!("Invalid {label} '{size}'"))
299299
}
@@ -314,3 +314,73 @@ fn validate_unspillable_reserve_ratio(s: &str) -> Result<f64, String> {
314314
}
315315
Ok(value)
316316
}
317+
318+
#[cfg(test)]
319+
mod tests {
320+
use super::*;
321+
322+
fn assert_conversion(input: &str, expected: Result<usize, String>) {
323+
let result = extract_memory_pool_size(input);
324+
match expected {
325+
Ok(v) => assert_eq!(result.unwrap(), v),
326+
Err(e) => assert_eq!(result.unwrap_err(), e),
327+
}
328+
}
329+
330+
#[test]
331+
fn memory_pool_size() -> Result<(), String> {
332+
// Test basic sizes without suffix, assumed to be bytes
333+
assert_conversion("5", Ok(5));
334+
assert_conversion("100", Ok(100));
335+
336+
// Test various units
337+
assert_conversion("5b", Ok(5));
338+
assert_conversion("4k", Ok(4 * 1024));
339+
assert_conversion("4kb", Ok(4 * 1024));
340+
assert_conversion("20m", Ok(20 * 1024 * 1024));
341+
assert_conversion("20mb", Ok(20 * 1024 * 1024));
342+
assert_conversion("2g", Ok(2 * 1024 * 1024 * 1024));
343+
assert_conversion("2gb", Ok(2 * 1024 * 1024 * 1024));
344+
assert_conversion("3t", Ok(3 * 1024 * 1024 * 1024 * 1024));
345+
assert_conversion("4tb", Ok(4 * 1024 * 1024 * 1024 * 1024));
346+
347+
// Test case insensitivity
348+
assert_conversion("4K", Ok(4 * 1024));
349+
assert_conversion("4KB", Ok(4 * 1024));
350+
assert_conversion("20M", Ok(20 * 1024 * 1024));
351+
assert_conversion("20MB", Ok(20 * 1024 * 1024));
352+
assert_conversion("2G", Ok(2 * 1024 * 1024 * 1024));
353+
assert_conversion("2GB", Ok(2 * 1024 * 1024 * 1024));
354+
assert_conversion("2T", Ok(2 * 1024 * 1024 * 1024 * 1024));
355+
356+
// Test decimal values
357+
assert_conversion("1.5g", Ok((1.5 * 1024.0 * 1024.0 * 1024.0) as usize));
358+
assert_conversion("0.5m", Ok((0.5 * 1024.0 * 1024.0) as usize));
359+
assert_conversion("9.5 gb", Ok((9.5 * 1024.0 * 1024.0 * 1024.0) as usize));
360+
361+
// Test with spaces between number and suffix
362+
assert_conversion("4 k", Ok(4 * 1024));
363+
assert_conversion("20 mb", Ok(20 * 1024 * 1024));
364+
365+
// Test invalid input
366+
assert_conversion(
367+
"invalid",
368+
Err("Invalid memory pool size 'invalid'".to_string()),
369+
);
370+
assert_conversion("4kbx", Err("Invalid memory pool size '4kbx'".to_string()));
371+
assert_conversion("-20mb", Err("Invalid memory pool size '-20mb'".to_string()));
372+
assert_conversion("-100", Err("Invalid memory pool size '-100'".to_string()));
373+
assert_conversion(
374+
"12k12k",
375+
Err("Invalid memory pool size '12k12k'".to_string()),
376+
);
377+
378+
// Test overflow
379+
assert_conversion(
380+
"99999999t",
381+
Err("memory pool size '99999999t' is too large".to_string()),
382+
);
383+
384+
Ok(())
385+
}
386+
}

0 commit comments

Comments
 (0)