diff --git a/c/sedona-proj/src/st_transform.rs b/c/sedona-proj/src/st_transform.rs index 63c209f9a..ccd7839f1 100644 --- a/c/sedona-proj/src/st_transform.rs +++ b/c/sedona-proj/src/st_transform.rs @@ -389,9 +389,12 @@ pub fn configure_global_proj_engine(builder: ProjCrsEngineBuilder) -> Result<()> /// Do something with the global thread-local PROJ engine, creating it if it has not /// already been created. -pub(crate) fn with_global_proj_engine( - mut func: impl FnMut(&CachingCrsEngine) -> Result<()>, -) -> Result<()> { +pub(crate) fn with_global_proj_engine< + R, + F: FnMut(&CachingCrsEngine) -> Result, +>( + mut func: F, +) -> Result { PROJ_ENGINE.with(|engine_cell| { // If there is already an engine, use it! if let Some(engine) = engine_cell.get() { @@ -417,8 +420,7 @@ pub(crate) fn with_global_proj_engine( engine_cell .set(CachingCrsEngine::new(proj_engine)) .map_err(|_| sedona_internal_datafusion_err!("Failed to set cached PROJ transform"))?; - func(engine_cell.get().unwrap())?; - Ok(()) + func(engine_cell.get().unwrap()) }) } diff --git a/rust/sedona-functions/src/st_srid.rs b/rust/sedona-functions/src/st_srid.rs index f69a21cd9..cda4ec21d 100644 --- a/rust/sedona-functions/src/st_srid.rs +++ b/rust/sedona-functions/src/st_srid.rs @@ -23,17 +23,17 @@ use arrow_array::{ use arrow_schema::DataType; use datafusion_common::{ cast::{as_string_view_array, as_struct_array}, - exec_err, DataFusionError, Result, ScalarValue, + DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility, }; use sedona_common::sedona_internal_err; use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF}; -use sedona_schema::crs::deserialize_crs; +use sedona_schema::crs::CachedCrsToSRIDMapping; use sedona_schema::datatypes::SedonaType; use sedona_schema::matchers::ArgMatcher; -use std::{collections::HashMap, iter::zip, sync::Arc}; +use std::{iter::zip, sync::Arc}; /// ST_Srid() scalar UDF 
implementation /// @@ -158,7 +158,7 @@ impl SedonaScalarKernel for StSridItemCrs { let item_array = item_crs_struct_array.column(0); let crs_string_array = as_string_view_array(item_crs_struct_array.column(1))?; - let mut batch_srids = HashMap::::new(); + let mut crs_to_srid_mapping = CachedCrsToSRIDMapping::with_capacity(item_array.len()); if let Some(item_nulls) = item_array.nulls() { for (is_valid, maybe_crs) in zip(item_nulls, crs_string_array) { @@ -167,11 +167,13 @@ impl SedonaScalarKernel for StSridItemCrs { continue; } - append_srid(maybe_crs, &mut batch_srids, &mut builder)?; + let srid = crs_to_srid_mapping.get_srid(maybe_crs)?; + builder.append_value(srid); } } else { for maybe_crs in crs_string_array { - append_srid(maybe_crs, &mut batch_srids, &mut builder)?; + let srid = crs_to_srid_mapping.get_srid(maybe_crs)?; + builder.append_value(srid); } } @@ -179,31 +181,6 @@ impl SedonaScalarKernel for StSridItemCrs { } } -fn append_srid( - maybe_crs: Option<&str>, - batch_srids: &mut HashMap, - builder: &mut UInt32Builder, -) -> Result<()> { - if let Some(crs_str) = maybe_crs { - if let Some(srid) = batch_srids.get(crs_str) { - builder.append_value(*srid); - } else if let Some(crs) = deserialize_crs(crs_str)? { - if let Some(srid) = crs.srid()? 
{ - batch_srids.insert(crs_str.to_string(), srid); - builder.append_value(srid); - } else { - return exec_err!("Can't extract SRID from item-level CRS '{crs_str}'"); - } - } else { - builder.append_value(0); - } - } else { - builder.append_value(0); - } - - Ok(()) -} - #[derive(Debug)] struct StCrs {} diff --git a/rust/sedona-raster-functions/src/rs_srid.rs b/rust/sedona-raster-functions/src/rs_srid.rs index 1697fe33b..0e02ad6b9 100644 --- a/rust/sedona-raster-functions/src/rs_srid.rs +++ b/rust/sedona-raster-functions/src/rs_srid.rs @@ -21,13 +21,12 @@ use arrow_array::builder::StringBuilder; use arrow_array::builder::UInt32Builder; use arrow_schema::DataType; use datafusion_common::error::Result; -use datafusion_common::DataFusionError; use datafusion_expr::{ scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility, }; use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF}; use sedona_raster::traits::RasterRef; -use sedona_schema::crs::deserialize_crs; +use sedona_schema::crs::CachedCrsToSRIDMapping; use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher}; /// RS_SRID() scalar UDF implementation @@ -97,45 +96,17 @@ impl SedonaScalarKernel for RsSrid { let executor = RasterExecutor::new(arg_types, args); let mut builder = UInt32Builder::with_capacity(executor.num_iterations()); + let mut crs_to_srid_mapping = + CachedCrsToSRIDMapping::with_capacity(executor.num_iterations()); executor.execute_raster_void(|_i, raster_opt| { - match raster_opt { - None => builder.append_null(), - Some(raster) => { - match raster.crs() { - None => { - // When no CRS is set, SRID is 0 - builder.append_value(0); - } - Some(crs_str) => { - let crs = deserialize_crs(crs_str).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to deserialize CRS: {e}" - )) - })?; - - match crs { - Some(crs_ref) => { - let srid = crs_ref.srid().map_err(|e| { - DataFusionError::Execution(format!( - "Failed to get SRID from CRS: {e}" - )) - })?; - - 
match srid { - Some(srid_val) => builder.append_value(srid_val), - None => { - return Err(DataFusionError::Execution( - "CRS has no SRID".to_string(), - )) - } - } - } - None => builder.append_value(0), - } - } - } - } - } + let Some(raster) = raster_opt else { + builder.append_null(); + return Ok(()); + }; + + let maybe_crs = raster.crs_str_ref(); + let srid = crs_to_srid_mapping.get_srid(maybe_crs)?; + builder.append_value(srid); Ok(()) })?; @@ -167,26 +138,14 @@ impl SedonaScalarKernel for RsCrs { StringBuilder::with_capacity(executor.num_iterations(), preallocate_bytes); executor.execute_raster_void(|_i, raster_opt| { - match raster_opt { - None => builder.append_null(), - Some(raster) => match raster.crs() { - None => builder.append_null(), - Some(crs_str) => { - let crs = deserialize_crs(crs_str).map_err(|e| { - DataFusionError::Execution(format!("Failed to deserialize CRS: {e}")) - })?; - - let crs_string = crs - .ok_or_else(|| { - DataFusionError::Execution( - "Failed to parse non-null CRS string".to_string(), - ) - })? 
- .to_crs_string(); - builder.append_value(crs_string); - } - }, - } + let Some(raster) = raster_opt else { + builder.append_null(); + return Ok(()); + }; + + // This is similar to ST_CRS: if no CRS is set, return "0" + let crs_str = raster.crs().unwrap_or("0"); + builder.append_value(crs_str); Ok(()) })?; @@ -200,7 +159,10 @@ mod tests { use arrow_array::{StringArray, UInt32Array}; use datafusion_common::ScalarValue; use datafusion_expr::ScalarUDF; + use sedona_raster::builder::RasterBuilder; + use sedona_raster::traits::{BandMetadata, RasterMetadata}; use sedona_schema::datatypes::RASTER; + use sedona_schema::raster::{BandDataType, StorageType}; use sedona_testing::compare::assert_array_equal; use sedona_testing::rasters::generate_test_rasters; use sedona_testing::testers::ScalarUdfTester; @@ -234,6 +196,34 @@ mod tests { // Test with null scalar let result = tester.invoke_scalar(ScalarValue::Null).unwrap(); tester.assert_scalar_result_equals(result, ScalarValue::UInt32(None)); + + // Test with raster missing CRS + let mut builder = RasterBuilder::new(1); + append_1x1_raster_with_crs(&mut builder, None); + let rasters = builder.finish().unwrap(); + + let result = tester.invoke_array(Arc::new(rasters)).unwrap(); + let expected: Arc = Arc::new(UInt32Array::from(vec![Some(0)])); + assert_array_equal(&result, &expected); + } + + #[test] + fn udf_srid_missing_srid_returns_error() { + let udf: ScalarUDF = rs_srid_udf().into(); + let tester = ScalarUdfTester::new(udf, vec![RASTER]); + + // A PROJJSON CRS without an authority identifier should error. 
+ let projjson_crs = "{\"type\":\"GeographicCRS\",\"name\":\"No authority id\"}"; + let mut builder = RasterBuilder::new(1); + append_1x1_raster_with_crs(&mut builder, Some(projjson_crs)); + let rasters = builder.finish().unwrap(); + + let err = tester.invoke_array(Arc::new(rasters)).unwrap_err(); + assert!( + err.to_string() + .contains("Can't extract SRID from item-level CRS"), + "unexpected error: {err}" + ); } #[test] @@ -255,8 +245,43 @@ mod tests { let result = tester.invoke_array(Arc::new(rasters)).unwrap(); assert_array_equal(&result, &expected); + // Test with raster missing CRS + let mut builder = RasterBuilder::new(1); + append_1x1_raster_with_crs(&mut builder, None); + let rasters = builder.finish().unwrap(); + + let result = tester.invoke_array(Arc::new(rasters)).unwrap(); + let expected: Arc = Arc::new(StringArray::from(vec![Some("0")])); + assert_array_equal(&result, &expected); + // Test with null scalar let result = tester.invoke_scalar(ScalarValue::Null).unwrap(); tester.assert_scalar_result_equals(result, ScalarValue::Utf8(None)); } + + fn append_1x1_raster_with_crs(builder: &mut RasterBuilder, crs: Option<&str>) { + let raster_metadata = RasterMetadata { + width: 1, + height: 1, + upperleft_x: 0.0, + upperleft_y: 0.0, + scale_x: 1.0, + scale_y: -1.0, + skew_x: 0.0, + skew_y: 0.0, + }; + builder.start_raster(&raster_metadata, crs).unwrap(); + builder + .start_band(BandMetadata { + datatype: BandDataType::UInt8, + nodata_value: None, + storage_type: StorageType::InDb, + outdb_url: None, + outdb_band_id: None, + }) + .unwrap(); + builder.band_data_writer().append_value([0u8]); + builder.finish_band().unwrap(); + builder.finish_raster().unwrap(); + } } diff --git a/rust/sedona-raster/src/array.rs b/rust/sedona-raster/src/array.rs index 7a5c167a6..07a4bce04 100644 --- a/rust/sedona-raster/src/array.rs +++ b/rust/sedona-raster/src/array.rs @@ -57,8 +57,6 @@ impl MetadataRef for RasterMetadata { } } -// - /// Implementation of MetadataRef for Arrow 
StructArray struct MetadataRefImpl<'a> { width_array: &'a UInt64Array, @@ -348,6 +346,14 @@ impl<'a> RasterRefImpl<'a> { bands, } } + + pub fn crs_str_ref(&self) -> Option<&'a str> { + if self.crs.is_null(self.bands.raster_index) { + None + } else { + Some(self.crs.value(self.bands.raster_index)) + } + } } impl<'a> RasterRef for RasterRefImpl<'a> { @@ -358,11 +364,7 @@ impl<'a> RasterRef for RasterRefImpl<'a> { #[inline(always)] fn crs(&self) -> Option<&str> { - if self.crs.is_null(self.bands.raster_index) { - None - } else { - Some(self.crs.value(self.bands.raster_index)) - } + self.crs_str_ref() } #[inline(always)] diff --git a/rust/sedona-schema/src/crs.rs b/rust/sedona-schema/src/crs.rs index 3813df11c..138355389 100644 --- a/rust/sedona-schema/src/crs.rs +++ b/rust/sedona-schema/src/crs.rs @@ -14,8 +14,11 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; +use datafusion_common::{ + exec_err, plan_datafusion_err, plan_err, DataFusionError, HashMap, Result, +}; use lru::LruCache; +use std::borrow::Cow; use std::cell::RefCell; use std::fmt::{Debug, Display}; use std::num::NonZeroUsize; @@ -50,7 +53,7 @@ pub fn deserialize_crs(crs_str: &str) -> Result { return Ok(cached); } - // Handle JSON strings "OGC:CRS84", "EPSG:4326", "{AUTH}:{CODE}" and "0" + // Handle JSON strings "OGC:CRS84", "EPSG:4326", "{AUTH}:{CODE}", WKT CRS strings and "0" let crs = if LngLat::is_str_lnglat(crs_str) { lnglat() } else if crs_str == "0" { @@ -78,6 +81,10 @@ pub fn deserialize_crs_from_obj(crs_value: &serde_json::Value) -> Result { } if let Some(crs_str) = crs_value.as_str() { + if crs_str.is_empty() || crs_str == "0" { + return Ok(None); + } + // Handle JSON strings "OGC:CRS84" and "EPSG:4326" if LngLat::is_str_lnglat(crs_str) { return Ok(lnglat()); @@ -101,6 +108,55 @@ pub fn deserialize_crs_from_obj(crs_value: 
&serde_json::Value) -> Result { Ok(Some(Arc::new(projjson))) } +/// Translates CRS strings into integer SRIDs, caching results to avoid expensive CRS deserialization. +pub struct CachedCrsToSRIDMapping { + cache: HashMap, u32>, +} + +impl Default for CachedCrsToSRIDMapping { + fn default() -> Self { + Self::new() + } +} + +impl CachedCrsToSRIDMapping { + /// Create a new CachedCrsToSRIDMapping with an empty cache. + pub fn new() -> Self { + Self { + cache: HashMap::new(), + } + } + + /// Create a new CachedCrsToSRIDMapping with the given initial capacity for the cache. + pub fn with_capacity(capacity: usize) -> Self { + Self { + cache: HashMap::with_capacity(capacity), + } + } + + /// Get the SRID for a given CRS string, using the cache to avoid expensive deserialization where possible. + /// Returns 0 when the CRS is missing or when `deserialize_crs` yields no CRS. Errors if deserialization fails, + /// if the SRID lookup fails, or if the CRS has no SRID. + pub fn get_srid(&mut self, maybe_crs: Option<&str>) -> Result { + if let Some(crs_str) = maybe_crs { + if let Some(srid) = self.cache.get(crs_str) { + Ok(*srid) + } else if let Some(crs) = deserialize_crs(crs_str)? { + if let Some(srid) = crs.srid()? { + self.cache.insert(Cow::Owned(crs_str.to_string()), srid); + Ok(srid) + } else { + exec_err!("Can't extract SRID from item-level CRS '{crs_str}'") + } + } else { + Ok(0) + } + } else { + Ok(0) + } + } +} + /// Longitude/latitude CRS (WGS84) /// /// A [`Crs`] that matches EPSG:4326 or OGC:CRS84.