diff --git a/c/sedona-proj/src/st_transform.rs b/c/sedona-proj/src/st_transform.rs index cefcce84d..bc7b17c1d 100644 --- a/c/sedona-proj/src/st_transform.rs +++ b/c/sedona-proj/src/st_transform.rs @@ -20,7 +20,7 @@ use arrow_schema::DataType; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; use geo_traits::to_geo::ToGeoGeometry; -use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarKernel}; +use sedona_expr::scalar_udf::{ArgMatcher, ScalarKernelRef, SedonaScalarKernel}; use sedona_functions::executor::WkbExecutor; use sedona_geometry::transform::{transform, CachingCrsEngine, CrsEngine, CrsTransform}; use sedona_geometry::wkb_factory::WKB_MIN_PROBABLE_BYTES; @@ -135,7 +135,9 @@ fn define_arg_indexes(arg_types: &[SedonaType], indexes: &mut TransformArgIndexe indexes.first_crs = 1; for (i, arg_type) in arg_types.iter().enumerate().skip(2) { - if *arg_type == SedonaType::Arrow(DataType::Utf8) { + if ArgMatcher::is_numeric().match_type(arg_type) + || ArgMatcher::is_string().match_type(arg_type) + { indexes.second_crs = Some(i); } else if *arg_type == SedonaType::Arrow(DataType::Boolean) { indexes.lenient = Some(i); @@ -157,19 +159,40 @@ impl SedonaScalarKernel for STTransform { let mut indexes = TransformArgIndexes::new(); define_arg_indexes(arg_types, &mut indexes); - let to_crs_opt = if let Some(second_crs_index) = indexes.second_crs { + if (arg_types.len() < 2) || (arg_types.len() > 4) { + return Ok(None); + } + + if !ArgMatcher::is_geometry_or_geography().match_type(&arg_types[indexes.wkb]) { + return Ok(None); + } + + if !(ArgMatcher::is_numeric().match_type(&arg_types[indexes.first_crs]) + || ArgMatcher::is_string().match_type(&arg_types[indexes.first_crs])) + { + return Ok(None); + } + + let scalar_arg_opt = if let Some(second_crs_index) = indexes.second_crs { scalar_args.get(second_crs_index).unwrap() } else { scalar_args.get(indexes.first_crs).unwrap() }; - match to_crs_opt { - Some(ScalarValue::Utf8(Some(to_crs))) => { + let crs_str_opt = if let Some(scalar_crs) = scalar_arg_opt { + to_crs_str(scalar_crs) + } else { + None + }; + + // If there is no CRS argument, we cannot determine the return type. + match crs_str_opt { + Some(to_crs) => { let val = serde_json::Value::String(to_crs.to_string()); let crs = deserialize_crs(&val)?; Ok(Some(SedonaType::Wkb(Edges::Planar, crs))) } - _ => Ok(Some(SedonaType::Wkb(Edges::Planar, None))), + _ => Ok(None), } } @@ -187,16 +210,14 @@ impl SedonaScalarKernel for STTransform { let mut indexes = TransformArgIndexes::new(); define_arg_indexes(arg_types, &mut indexes); - let first_crs = get_scalar_str(args, indexes.first_crs).ok_or_else(|| { - DataFusionError::Execution("First argument must be a scalar string".into()) - })?; + let first_crs = get_crs_str(args, indexes.first_crs).unwrap(); let lenient = indexes .lenient .is_some_and(|i| get_scalar_bool(args, i).unwrap_or(false)); let second_crs = if let Some(second_crs_index) = indexes.second_crs { - get_scalar_str(args, second_crs_index) + get_crs_str(args, second_crs_index) } else { None }; @@ -270,12 +291,23 @@ fn parse_source_crs(source_type: &SedonaType) -> Result> { } } -fn get_scalar_str(args: &[ColumnarValue], index: usize) -> Option { - if let Some(ColumnarValue::Scalar(ScalarValue::Utf8(opt_str))) = args.get(index) { - opt_str.clone() - } else { - None +fn to_crs_str(scalar_arg: &ScalarValue) -> Option { + if let Ok(ScalarValue::Utf8(Some(crs))) = scalar_arg.cast_to(&DataType::Utf8) { + if crs.chars().all(|c| c.is_ascii_digit()) { + return Some(format!("EPSG:{crs}")); + } else { + return Some(crs); + } + } + + None +} + +fn get_crs_str(args: &[ColumnarValue], index: usize) -> Option { + if let ColumnarValue::Scalar(scalar_crs) = &args[index] { + return to_crs_str(scalar_crs); } + None } fn get_scalar_bool(args: &[ColumnarValue], index: usize) -> Option { @@ -303,6 +335,64 @@ mod tests { const NAD83ZONE6PROJ: &str = "EPSG:2230"; const WGS84: &str = "EPSG:4326"; + #[rstest] + fn invalid_arg_checks() { + let udf: SedonaScalarUDF = + SedonaScalarUDF::from_kernel("st_transform", st_transform_impl()); + + // No args + let result = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &[], + scalar_arguments: &[], + }); + assert!(result.is_err() && result.unwrap_err().to_string().contains("No kernel matching arguments")); + + // Too many args + let arg_types = [ + WKB_GEOMETRY, + SedonaType::Arrow(DataType::Utf8), + SedonaType::Arrow(DataType::Utf8), + SedonaType::Arrow(DataType::Boolean), + SedonaType::Arrow(DataType::Int32), + ]; + let arg_fields: Vec> = arg_types + .iter() + .map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap())) + .collect(); + let result = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &[None, None, None, None, None], + }); + assert!(result.is_err() && result.unwrap_err().to_string().contains("No kernel matching arguments")); + + // First arg not geometry + let arg_types = [ + SedonaType::Arrow(DataType::Utf8), + SedonaType::Arrow(DataType::Utf8), + ]; + let arg_fields: Vec> = arg_types + .iter() + .map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap())) + .collect(); + let result = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &[None, None], + }); + assert!(result.is_err() && result.unwrap_err().to_string().contains("No kernel matching arguments")); + + // Second arg not string or numeric + let arg_types = [WKB_GEOMETRY, SedonaType::Arrow(DataType::Boolean)]; + let arg_fields: Vec> = arg_types + .iter() + .map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap())) + .collect(); + let result = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &[None, None], + }); + assert!(result.is_err() && result.unwrap_err().to_string().contains("No kernel matching arguments")); + } + #[rstest] fn test_invoke_batch_with_geo_crs() { // From-CRS pulled from sedona type @@ -329,6 +419,32 @@ mod tests { ); } + #[rstest] + fn test_invoke_with_srids() { + // Use an integer SRID for the to CRS + let arg_types = [ + SedonaType::Wkb(Edges::Planar, lnglat()), + SedonaType::Arrow(DataType::UInt32), + ]; + + let wkb = create_array(&[None, Some("POINT (79.3871 43.6426)")], &arg_types[0]); + + let scalar_args = vec![ScalarValue::UInt32(Some(2230))]; + + let expected = create_array_value( + &[None, Some("POINT (-21508577.363421552 34067918.06097863)")], + &SedonaType::Wkb(Edges::Planar, get_crs(NAD83ZONE6PROJ)), + ); + + let (result_type, result_col) = + invoke_udf_test(wkb, scalar_args, arg_types.to_vec()).unwrap(); + assert_value_equal(&result_col, &expected); + assert_eq!( + result_type, + SedonaType::Wkb(Edges::Planar, get_crs(NAD83ZONE6PROJ)) + ); + } + #[rstest] fn test_invoke_batch_with_lenient() { let arg_types = [ @@ -372,7 +488,7 @@ mod tests { } #[rstest] - fn test_invoke_batch_with_string_source() { + fn test_invoke_batch_with_source_arg() { let arg_types = [ WKB_GEOMETRY, SedonaType::Arrow(DataType::Utf8), @@ -392,6 +508,26 @@ mod tests { &SedonaType::Wkb(Edges::Planar, Some(get_crs(NAD83ZONE6PROJ).unwrap())), ); + let (result_type, result_col) = + invoke_udf_test(wkb.clone(), scalar_args, arg_types.to_vec()).unwrap(); + assert_value_equal(&result_col, &expected); + assert_eq!( + result_type, + SedonaType::Wkb(Edges::Planar, Some(get_crs(NAD83ZONE6PROJ).unwrap())) + ); + + // Test with integer SRIDs + let arg_types = [ + WKB_GEOMETRY, + SedonaType::Arrow(DataType::Int32), + SedonaType::Arrow(DataType::Int32), + ]; + + let scalar_args = vec![ + ScalarValue::Int32(Some(4326)), + ScalarValue::Int32(Some(2230)), + ]; + let (result_type, result_col) = invoke_udf_test(wkb, scalar_args, arg_types.to_vec()).unwrap(); assert_value_equal(&result_col, &expected);