diff --git a/python/sedonadb/tests/functions/test_transforms.py b/python/sedonadb/tests/functions/test_transforms.py index 3a5f8cc71..1e11ecf95 100644 --- a/python/sedonadb/tests/functions/test_transforms.py +++ b/python/sedonadb/tests/functions/test_transforms.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. import pytest -from sedonadb.testing import PostGIS, SedonaDB +import pyproj +from sedonadb.testing import geom_or_null, PostGIS, SedonaDB, val_or_null @pytest.mark.parametrize("eng", [SedonaDB, PostGIS]) @@ -26,3 +27,42 @@ def test_st_transform(eng): "POINT (111319.490793274 111325.142866385)", wkt_precision=9, ) + + +@pytest.mark.parametrize("eng", [SedonaDB, PostGIS]) +@pytest.mark.parametrize( + ("geom", "srid", "expected_srid"), + [ + ("POINT (1 1)", None, None), + ("POINT (1 1)", 3857, 3857), + ("POINT (1 1)", 0, None), + ], +) +def test_st_setsrid(eng, geom, srid, expected_srid): + eng = eng.create_or_skip() + result = eng.execute_and_collect( + f"SELECT ST_SetSrid({geom_or_null(geom)}, {val_or_null(srid)})" + ) + df = eng.result_to_pandas(result) + if expected_srid is None: + assert df.crs is None + else: + assert df.crs == pyproj.CRS(expected_srid) + + +# PostGIS does not handle String CRS input to ST_SetSrid +@pytest.mark.parametrize("eng", [SedonaDB]) +@pytest.mark.parametrize( + ("geom", "srid", "expected_srid"), + [ + ("POINT (1 1)", "EPSG:26920", 26920), + ("POINT (1 1)", pyproj.CRS("EPSG:26920").to_json(), 26920), + ], +) +def test_st_setsrid_sedonadb(eng, geom, srid, expected_srid): + eng = eng.create_or_skip() + result = eng.execute_and_collect( + f"SELECT ST_SetSrid({geom_or_null(geom)}, '{srid}')" + ) + df = eng.result_to_pandas(result) + assert df.crs.to_epsg() == expected_srid diff --git a/rust/sedona-functions/src/st_setsrid.rs b/rust/sedona-functions/src/st_setsrid.rs index f559fbf35..d02a0f3f7 100644 --- a/rust/sedona-functions/src/st_setsrid.rs +++ b/rust/sedona-functions/src/st_setsrid.rs @@ -88,8 +88,12 @@ impl SedonaScalarKernel for STSetSRID { if let ScalarValue::Utf8(maybe_crs) = scalar_crs.cast_to(&DataType::Utf8)? { let new_crs = match maybe_crs { Some(crs) => { - validate_crs(&crs, self.engine.as_ref())?; - deserialize_crs(&serde_json::Value::String(crs))? + if crs == "0" { + None + } else { + validate_crs(&crs, self.engine.as_ref())?; + deserialize_crs(&serde_json::Value::String(crs))? + } } None => None, }; @@ -173,6 +177,7 @@ mod test { let good_crs_scalar = ScalarValue::Utf8(Some("EPSG:4326".to_string())); let null_crs_scalar = ScalarValue::Utf8(None); let epsg_code_scalar = ScalarValue::Int32(Some(4326)); + let unset_scalar = ScalarValue::Int32(Some(0)); let questionable_crs_scalar = ScalarValue::Utf8(Some("gazornenplat".to_string())); // Call with a string scalar destination @@ -193,6 +198,12 @@ mod test { assert_eq!(return_type, wkb_lnglat); assert_value_equal(&result, &geom_lnglat); + // Call with an integer code of 0 (should unset the output crs) + let (return_type, result) = + call_udf(&udf, geom_lnglat.clone(), unset_scalar.clone()).unwrap(); + assert_eq!(return_type, WKB_GEOMETRY); + assert_value_equal(&result, &geom_arg); + // Ensure that an engine can reject a CRS if the UDF was constructed with one let udf_with_validation: ScalarUDF = st_set_srid_with_engine_udf(Some(Arc::new(ExtremelyUnusefulEngine {}))).into();