Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 152 additions & 16 deletions c/sedona-proj/src/st_transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use geo_traits::to_geo::ToGeoGeometry;
use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarKernel};
use sedona_expr::scalar_udf::{ArgMatcher, ScalarKernelRef, SedonaScalarKernel};
use sedona_functions::executor::WkbExecutor;
use sedona_geometry::transform::{transform, CachingCrsEngine, CrsEngine, CrsTransform};
use sedona_geometry::wkb_factory::WKB_MIN_PROBABLE_BYTES;
Expand Down Expand Up @@ -135,7 +135,9 @@ fn define_arg_indexes(arg_types: &[SedonaType], indexes: &mut TransformArgIndexe
indexes.first_crs = 1;

for (i, arg_type) in arg_types.iter().enumerate().skip(2) {
if *arg_type == SedonaType::Arrow(DataType::Utf8) {
if ArgMatcher::is_numeric().match_type(arg_type)
|| ArgMatcher::is_string().match_type(arg_type)
{
indexes.second_crs = Some(i);
} else if *arg_type == SedonaType::Arrow(DataType::Boolean) {
indexes.lenient = Some(i);
Expand All @@ -157,19 +159,40 @@ impl SedonaScalarKernel for STTransform {
let mut indexes = TransformArgIndexes::new();
define_arg_indexes(arg_types, &mut indexes);

let to_crs_opt = if let Some(second_crs_index) = indexes.second_crs {
if (arg_types.len() < 2) || (arg_types.len() > 4) {
return Ok(None);
}

if !ArgMatcher::is_geometry_or_geography().match_type(&arg_types[indexes.wkb]) {
return Ok(None);
}

if !(ArgMatcher::is_numeric().match_type(&arg_types[indexes.first_crs])
|| ArgMatcher::is_string().match_type(&arg_types[indexes.first_crs]))
{
return Ok(None);
}

let scalar_arg_opt = if let Some(second_crs_index) = indexes.second_crs {
scalar_args.get(second_crs_index).unwrap()
} else {
scalar_args.get(indexes.first_crs).unwrap()
};

match to_crs_opt {
Some(ScalarValue::Utf8(Some(to_crs))) => {
let crs_str_opt = if let Some(scalar_crs) = scalar_arg_opt {
to_crs_str(scalar_crs)
} else {
None
};

// If there is no CRS argument, we cannot determine the return type.
match crs_str_opt {
Some(to_crs) => {
let val = serde_json::Value::String(to_crs.to_string());
let crs = deserialize_crs(&val)?;
Ok(Some(SedonaType::Wkb(Edges::Planar, crs)))
}
_ => Ok(Some(SedonaType::Wkb(Edges::Planar, None))),
_ => Ok(None),
}
}

Expand All @@ -187,16 +210,14 @@ impl SedonaScalarKernel for STTransform {
let mut indexes = TransformArgIndexes::new();
define_arg_indexes(arg_types, &mut indexes);

let first_crs = get_scalar_str(args, indexes.first_crs).ok_or_else(|| {
DataFusionError::Execution("First argument must be a scalar string".into())
})?;
let first_crs = get_crs_str(args, indexes.first_crs).unwrap();

let lenient = indexes
.lenient
.is_some_and(|i| get_scalar_bool(args, i).unwrap_or(false));

let second_crs = if let Some(second_crs_index) = indexes.second_crs {
get_scalar_str(args, second_crs_index)
get_crs_str(args, second_crs_index)
} else {
None
};
Expand Down Expand Up @@ -270,12 +291,23 @@ fn parse_source_crs(source_type: &SedonaType) -> Result<Option<String>> {
}
}

fn get_scalar_str(args: &[ColumnarValue], index: usize) -> Option<String> {
if let Some(ColumnarValue::Scalar(ScalarValue::Utf8(opt_str))) = args.get(index) {
opt_str.clone()
} else {
None
/// Converts a scalar CRS argument into a CRS string.
///
/// The scalar is cast to `Utf8`. A non-empty value consisting solely of ASCII
/// digits is treated as a bare SRID and prefixed with `EPSG:`; any other
/// string is returned unchanged.
///
/// Returns `None` when the cast fails or produces a null string.
fn to_crs_str(scalar_arg: &ScalarValue) -> Option<String> {
    let ScalarValue::Utf8(Some(crs)) = scalar_arg.cast_to(&DataType::Utf8).ok()? else {
        return None;
    };

    // `all` is vacuously true on an empty iterator, so an empty string must be
    // excluded explicitly or it would be rewritten to the malformed "EPSG:".
    let is_bare_srid = !crs.is_empty() && crs.chars().all(|c| c.is_ascii_digit());

    Some(if is_bare_srid {
        format!("EPSG:{crs}")
    } else {
        crs
    })
}

/// Extracts the CRS string from the argument at `index`.
///
/// Returns `None` when `index` is out of range, when the argument is not a
/// scalar, or when `to_crs_str` cannot convert the scalar.
fn get_crs_str(args: &[ColumnarValue], index: usize) -> Option<String> {
    // Checked access: an out-of-range index yields `None` instead of panicking.
    match args.get(index)? {
        ColumnarValue::Scalar(scalar_crs) => to_crs_str(scalar_crs),
        _ => None,
    }
}

fn get_scalar_bool(args: &[ColumnarValue], index: usize) -> Option<bool> {
Expand Down Expand Up @@ -303,6 +335,64 @@ mod tests {
const NAD83ZONE6PROJ: &str = "EPSG:2230";
const WGS84: &str = "EPSG:4326";

#[rstest]
fn invalid_arg_checks() {
let udf: SedonaScalarUDF =
SedonaScalarUDF::from_kernel("st_transform", st_transform_impl());

// No args
let result = udf.return_field_from_args(ReturnFieldArgs {
arg_fields: &[],
scalar_arguments: &[],
});
assert!(result.is_err() && result.unwrap_err().to_string().contains("No kernel matching arguments"));

// Too many args
let arg_types = [
WKB_GEOMETRY,
SedonaType::Arrow(DataType::Utf8),
SedonaType::Arrow(DataType::Utf8),
SedonaType::Arrow(DataType::Boolean),
SedonaType::Arrow(DataType::Int32),
];
let arg_fields: Vec<Arc<Field>> = arg_types
.iter()
.map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap()))
.collect();
let result = udf.return_field_from_args(ReturnFieldArgs {
arg_fields: &arg_fields,
scalar_arguments: &[None, None, None, None, None],
});
assert!(result.is_err() && result.unwrap_err().to_string().contains("No kernel matching arguments"));

// First arg not geometry
let arg_types = [
SedonaType::Arrow(DataType::Utf8),
SedonaType::Arrow(DataType::Utf8),
];
let arg_fields: Vec<Arc<Field>> = arg_types
.iter()
.map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap()))
.collect();
let result = udf.return_field_from_args(ReturnFieldArgs {
arg_fields: &arg_fields,
scalar_arguments: &[None, None],
});
assert!(result.is_err() && result.unwrap_err().to_string().contains("No kernel matching arguments"));

// Second arg not string or numeric
let arg_types = [WKB_GEOMETRY, SedonaType::Arrow(DataType::Boolean)];
let arg_fields: Vec<Arc<Field>> = arg_types
.iter()
.map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap()))
.collect();
let result = udf.return_field_from_args(ReturnFieldArgs {
arg_fields: &arg_fields,
scalar_arguments: &[None, None],
});
assert!(result.is_err() && result.unwrap_err().to_string().contains("No kernel matching arguments"));
}

#[rstest]
fn test_invoke_batch_with_geo_crs() {
// From-CRS pulled from sedona type
Expand All @@ -329,6 +419,32 @@ mod tests {
);
}

#[rstest]
fn test_invoke_with_srids() {
    // The target CRS is passed as an integer SRID rather than a string.
    let arg_types = [
        SedonaType::Wkb(Edges::Planar, lnglat()),
        SedonaType::Arrow(DataType::UInt32),
    ];
    let srid_args = vec![ScalarValue::UInt32(Some(2230))];

    // The result is expected to carry the CRS named by the SRID.
    let expected_type = SedonaType::Wkb(Edges::Planar, get_crs(NAD83ZONE6PROJ));

    let input = create_array(&[None, Some("POINT (79.3871 43.6426)")], &arg_types[0]);
    let expected = create_array_value(
        &[None, Some("POINT (-21508577.363421552 34067918.06097863)")],
        &expected_type,
    );

    let (result_type, result_col) =
        invoke_udf_test(input, srid_args, arg_types.to_vec()).unwrap();

    assert_value_equal(&result_col, &expected);
    assert_eq!(result_type, expected_type);
}

#[rstest]
fn test_invoke_batch_with_lenient() {
let arg_types = [
Expand Down Expand Up @@ -372,7 +488,7 @@ mod tests {
}

#[rstest]
fn test_invoke_batch_with_string_source() {
fn test_invoke_batch_with_source_arg() {
let arg_types = [
WKB_GEOMETRY,
SedonaType::Arrow(DataType::Utf8),
Expand All @@ -392,6 +508,26 @@ mod tests {
&SedonaType::Wkb(Edges::Planar, Some(get_crs(NAD83ZONE6PROJ).unwrap())),
);

let (result_type, result_col) =
invoke_udf_test(wkb.clone(), scalar_args, arg_types.to_vec()).unwrap();
assert_value_equal(&result_col, &expected);
assert_eq!(
result_type,
SedonaType::Wkb(Edges::Planar, Some(get_crs(NAD83ZONE6PROJ).unwrap()))
);

// Test with integer SRIDs
let arg_types = [
WKB_GEOMETRY,
SedonaType::Arrow(DataType::Int32),
SedonaType::Arrow(DataType::Int32),
];

let scalar_args = vec![
ScalarValue::Int32(Some(4326)),
ScalarValue::Int32(Some(2230)),
];

let (result_type, result_col) =
invoke_udf_test(wkb, scalar_args, arg_types.to_vec()).unwrap();
assert_value_equal(&result_col, &expected);
Expand Down
Loading