Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions c/sedona-proj/src/st_transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,12 @@ pub fn configure_global_proj_engine(builder: ProjCrsEngineBuilder) -> Result<()>

/// Do something with the global thread-local PROJ engine, creating it if it has not
/// already been created.
pub(crate) fn with_global_proj_engine(
mut func: impl FnMut(&CachingCrsEngine<ProjCrsEngine>) -> Result<()>,
) -> Result<()> {
pub(crate) fn with_global_proj_engine<
R,
F: FnMut(&CachingCrsEngine<ProjCrsEngine>) -> Result<R>,
>(
mut func: F,
) -> Result<R> {
PROJ_ENGINE.with(|engine_cell| {
// If there is already an engine, use it!
if let Some(engine) = engine_cell.get() {
Expand All @@ -417,8 +420,7 @@ pub(crate) fn with_global_proj_engine(
engine_cell
.set(CachingCrsEngine::new(proj_engine))
.map_err(|_| sedona_internal_datafusion_err!("Failed to set cached PROJ transform"))?;
func(engine_cell.get().unwrap())?;
Ok(())
func(engine_cell.get().unwrap())
})
}

Expand Down
39 changes: 8 additions & 31 deletions rust/sedona-functions/src/st_srid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ use arrow_array::{
use arrow_schema::DataType;
use datafusion_common::{
cast::{as_string_view_array, as_struct_array},
exec_err, DataFusionError, Result, ScalarValue,
DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{
scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility,
};
use sedona_common::sedona_internal_err;
use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
use sedona_schema::crs::deserialize_crs;
use sedona_schema::crs::CachedCrsToSRIDMapping;
use sedona_schema::datatypes::SedonaType;
use sedona_schema::matchers::ArgMatcher;
use std::{collections::HashMap, iter::zip, sync::Arc};
use std::{iter::zip, sync::Arc};

/// ST_Srid() scalar UDF implementation
///
Expand Down Expand Up @@ -158,7 +158,7 @@ impl SedonaScalarKernel for StSridItemCrs {

let item_array = item_crs_struct_array.column(0);
let crs_string_array = as_string_view_array(item_crs_struct_array.column(1))?;
let mut batch_srids = HashMap::<String, u32>::new();
let mut crs_to_srid_mapping = CachedCrsToSRIDMapping::with_capacity(item_array.len());

if let Some(item_nulls) = item_array.nulls() {
for (is_valid, maybe_crs) in zip(item_nulls, crs_string_array) {
Expand All @@ -167,43 +167,20 @@ impl SedonaScalarKernel for StSridItemCrs {
continue;
}

append_srid(maybe_crs, &mut batch_srids, &mut builder)?;
let srid = crs_to_srid_mapping.get_srid(maybe_crs)?;
builder.append_value(srid);
}
} else {
for maybe_crs in crs_string_array {
append_srid(maybe_crs, &mut batch_srids, &mut builder)?;
let srid = crs_to_srid_mapping.get_srid(maybe_crs)?;
builder.append_value(srid);
}
}

executor.finish(Arc::new(builder.finish()))
}
}

/// Append the SRID corresponding to one item-level CRS string to `builder`.
///
/// * A missing CRS string appends `0`.
/// * A CRS string already present in `batch_srids` appends the remembered SRID
///   without deserializing the CRS again.
/// * A CRS string for which `deserialize_crs` yields no CRS appends `0`.
/// * Otherwise the SRID reported by the deserialized CRS is stored in
///   `batch_srids` (keyed by the CRS string) and appended.
///
/// # Errors
///
/// Propagates failures from `deserialize_crs` and from the CRS's `srid()`
/// lookup, and fails via `exec_err!` when the CRS reports no SRID.
fn append_srid(
    maybe_crs: Option<&str>,
    batch_srids: &mut HashMap<String, u32>,
    builder: &mut UInt32Builder,
) -> Result<()> {
    // No CRS string at all: the SRID is 0.
    let Some(crs_str) = maybe_crs else {
        builder.append_value(0);
        return Ok(());
    };

    // Fast path: this exact CRS string was already resolved.
    if let Some(&known_srid) = batch_srids.get(crs_str) {
        builder.append_value(known_srid);
        return Ok(());
    }

    // The string deserialized to "no CRS": the SRID is 0 (not remembered).
    let Some(crs) = deserialize_crs(crs_str)? else {
        builder.append_value(0);
        return Ok(());
    };

    let Some(resolved_srid) = crs.srid()? else {
        return exec_err!("Can't extract SRID from item-level CRS '{crs_str}'");
    };

    batch_srids.insert(crs_str.to_string(), resolved_srid);
    builder.append_value(resolved_srid);
    Ok(())
}

#[derive(Debug)]
struct StCrs {}

Expand Down
145 changes: 85 additions & 60 deletions rust/sedona-raster-functions/src/rs_srid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ use arrow_array::builder::StringBuilder;
use arrow_array::builder::UInt32Builder;
use arrow_schema::DataType;
use datafusion_common::error::Result;
use datafusion_common::DataFusionError;
use datafusion_expr::{
scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility,
};
use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
use sedona_raster::traits::RasterRef;
use sedona_schema::crs::deserialize_crs;
use sedona_schema::crs::CachedCrsToSRIDMapping;
use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};

/// RS_SRID() scalar UDF implementation
Expand Down Expand Up @@ -97,45 +96,17 @@ impl SedonaScalarKernel for RsSrid {
let executor = RasterExecutor::new(arg_types, args);
let mut builder = UInt32Builder::with_capacity(executor.num_iterations());

let mut crs_to_srid_mapping =
CachedCrsToSRIDMapping::with_capacity(executor.num_iterations());
executor.execute_raster_void(|_i, raster_opt| {
match raster_opt {
None => builder.append_null(),
Some(raster) => {
match raster.crs() {
None => {
// When no CRS is set, SRID is 0
builder.append_value(0);
}
Some(crs_str) => {
let crs = deserialize_crs(crs_str).map_err(|e| {
DataFusionError::Execution(format!(
"Failed to deserialize CRS: {e}"
))
})?;

match crs {
Some(crs_ref) => {
let srid = crs_ref.srid().map_err(|e| {
DataFusionError::Execution(format!(
"Failed to get SRID from CRS: {e}"
))
})?;

match srid {
Some(srid_val) => builder.append_value(srid_val),
None => {
return Err(DataFusionError::Execution(
"CRS has no SRID".to_string(),
))
}
}
}
None => builder.append_value(0),
}
}
}
}
}
let Some(raster) = raster_opt else {
builder.append_null();
return Ok(());
};

let maybe_crs = raster.crs_str_ref();
let srid = crs_to_srid_mapping.get_srid(maybe_crs)?;
builder.append_value(srid);
Ok(())
})?;

Expand Down Expand Up @@ -167,26 +138,14 @@ impl SedonaScalarKernel for RsCrs {
StringBuilder::with_capacity(executor.num_iterations(), preallocate_bytes);

executor.execute_raster_void(|_i, raster_opt| {
match raster_opt {
None => builder.append_null(),
Some(raster) => match raster.crs() {
None => builder.append_null(),
Some(crs_str) => {
let crs = deserialize_crs(crs_str).map_err(|e| {
DataFusionError::Execution(format!("Failed to deserialize CRS: {e}"))
})?;

let crs_string = crs
.ok_or_else(|| {
DataFusionError::Execution(
"Failed to parse non-null CRS string".to_string(),
)
})?
.to_crs_string();
builder.append_value(crs_string);
}
},
}
let Some(raster) = raster_opt else {
builder.append_null();
return Ok(());
};

// This is similar to ST_CRS: if no CRS is set, return "0"
let crs_str = raster.crs().unwrap_or("0");
builder.append_value(crs_str);
Ok(())
})?;

Expand All @@ -200,7 +159,10 @@ mod tests {
use arrow_array::{StringArray, UInt32Array};
use datafusion_common::ScalarValue;
use datafusion_expr::ScalarUDF;
use sedona_raster::builder::RasterBuilder;
use sedona_raster::traits::{BandMetadata, RasterMetadata};
use sedona_schema::datatypes::RASTER;
use sedona_schema::raster::{BandDataType, StorageType};
use sedona_testing::compare::assert_array_equal;
use sedona_testing::rasters::generate_test_rasters;
use sedona_testing::testers::ScalarUdfTester;
Expand Down Expand Up @@ -234,6 +196,34 @@ mod tests {
// Test with null scalar
let result = tester.invoke_scalar(ScalarValue::Null).unwrap();
tester.assert_scalar_result_equals(result, ScalarValue::UInt32(None));

// Test with raster missing CRS
let mut builder = RasterBuilder::new(1);
append_1x1_raster_with_crs(&mut builder, None);
let rasters = builder.finish().unwrap();

let result = tester.invoke_array(Arc::new(rasters)).unwrap();
let expected: Arc<dyn arrow_array::Array> = Arc::new(UInt32Array::from(vec![Some(0)]));
assert_array_equal(&result, &expected);
}

#[test]
fn udf_srid_missing_srid_returns_error() {
// RS_SRID must fail for a raster whose PROJJSON CRS carries no authority
// identifier, since no SRID can be derived from it.
let crs_without_authority = "{\"type\":\"GeographicCRS\",\"name\":\"No authority id\"}";

// Build a single-row raster array that uses the CRS above.
let mut raster_builder = RasterBuilder::new(1);
append_1x1_raster_with_crs(&mut raster_builder, Some(crs_without_authority));
let raster_array = raster_builder.finish().unwrap();

let udf: ScalarUDF = rs_srid_udf().into();
let tester = ScalarUdfTester::new(udf, vec![RASTER]);

let err = tester.invoke_array(Arc::new(raster_array)).unwrap_err();
let message = err.to_string();
assert!(
message.contains("Can't extract SRID from item-level CRS"),
"unexpected error: {err}"
);
}

#[test]
Expand All @@ -255,8 +245,43 @@ mod tests {
let result = tester.invoke_array(Arc::new(rasters)).unwrap();
assert_array_equal(&result, &expected);

// Test with raster missing CRS
let mut builder = RasterBuilder::new(1);
append_1x1_raster_with_crs(&mut builder, None);
let rasters = builder.finish().unwrap();

let result = tester.invoke_array(Arc::new(rasters)).unwrap();
let expected: Arc<dyn arrow_array::Array> = Arc::new(StringArray::from(vec![Some("0")]));
assert_array_equal(&result, &expected);

// Test with null scalar
let result = tester.invoke_scalar(ScalarValue::Null).unwrap();
tester.assert_scalar_result_equals(result, ScalarValue::Utf8(None));
}

/// Append a single 1x1 raster with one in-database `UInt8` band (pixel value 0)
/// to `builder`, tagging the raster with `crs` (or with no CRS when `None`).
///
/// Panics if any builder step fails; intended for test setup only.
fn append_1x1_raster_with_crs(builder: &mut RasterBuilder, crs: Option<&str>) {
// One-pixel raster with unit-sized, north-up pixels and no skew.
let metadata = RasterMetadata {
width: 1,
height: 1,
upperleft_x: 0.0,
upperleft_y: 0.0,
scale_x: 1.0,
scale_y: -1.0,
skew_x: 0.0,
skew_y: 0.0,
};
// A single in-db band with no nodata value and no out-db reference.
let band = BandMetadata {
datatype: BandDataType::UInt8,
nodata_value: None,
storage_type: StorageType::InDb,
outdb_url: None,
outdb_band_id: None,
};

builder.start_raster(&metadata, crs).unwrap();
builder.start_band(band).unwrap();
builder.band_data_writer().append_value([0u8]);
builder.finish_band().unwrap();
builder.finish_raster().unwrap();
}
}
16 changes: 9 additions & 7 deletions rust/sedona-raster/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ impl MetadataRef for RasterMetadata {
}
}

//

/// Implementation of MetadataRef for Arrow StructArray
struct MetadataRefImpl<'a> {
width_array: &'a UInt64Array,
Expand Down Expand Up @@ -348,6 +346,14 @@ impl<'a> RasterRefImpl<'a> {
bands,
}
}

/// Return the CRS string stored for this raster, or `None` when the CRS
/// entry at this raster's index is null.
///
/// The returned string slice carries the lifetime `'a` rather than the
/// lifetime of the `&self` borrow.
pub fn crs_str_ref(&self) -> Option<&'a str> {
let row = self.bands.raster_index;
if self.crs.is_null(row) {
return None;
}
Some(self.crs.value(row))
}
}

impl<'a> RasterRef for RasterRefImpl<'a> {
Expand All @@ -358,11 +364,7 @@ impl<'a> RasterRef for RasterRefImpl<'a> {

#[inline(always)]
fn crs(&self) -> Option<&str> {
if self.crs.is_null(self.bands.raster_index) {
None
} else {
Some(self.crs.value(self.bands.raster_index))
}
self.crs_str_ref()
}

#[inline(always)]
Expand Down
Loading