Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion python/sedonadb/tests/functions/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import pytest
from sedonadb.testing import PostGIS, SedonaDB
import pyproj
from sedonadb.testing import geom_or_null, PostGIS, SedonaDB, val_or_null


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
Expand All @@ -26,3 +27,42 @@ def test_st_transform(eng):
"POINT (111319.490793274 111325.142866385)",
wkt_precision=9,
)


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
@pytest.mark.parametrize(
("geom", "srid", "expected_srid"),
[
("POINT (1 1)", None, None),
("POINT (1 1)", 3857, 3857),
("POINT (1 1)", 0, None),
],
)
def test_st_setsrid(eng, geom, srid, expected_srid):
eng = eng.create_or_skip()
result = eng.execute_and_collect(
f"SELECT ST_SetSrid({geom_or_null(geom)}, {val_or_null(srid)})"
)
df = eng.result_to_pandas(result)
if expected_srid is None:
assert df.crs is None
else:
assert df.crs == pyproj.CRS(expected_srid)


# PostGIS does not handle String CRS input to ST_SetSrid
@pytest.mark.parametrize("eng", [SedonaDB])
@pytest.mark.parametrize(
("geom", "srid", "expected_srid"),
[
("POINT (1 1)", "EPSG:26920", 26920),
("POINT (1 1)", pyproj.CRS("EPSG:26920").to_json(), 26920),
],
)
def test_st_setsrid_sedonadb(eng, geom, srid, expected_srid):
eng = eng.create_or_skip()
result = eng.execute_and_collect(
f"SELECT ST_SetSrid({geom_or_null(geom)}, '{srid}')"
)
df = eng.result_to_pandas(result)
assert df.crs.to_epsg() == expected_srid
15 changes: 13 additions & 2 deletions rust/sedona-functions/src/st_setsrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,12 @@ impl SedonaScalarKernel for STSetSRID {
if let ScalarValue::Utf8(maybe_crs) = scalar_crs.cast_to(&DataType::Utf8)? {
let new_crs = match maybe_crs {
Some(crs) => {
validate_crs(&crs, self.engine.as_ref())?;
deserialize_crs(&serde_json::Value::String(crs))?
if crs == "0" {
None
} else {
validate_crs(&crs, self.engine.as_ref())?;
deserialize_crs(&serde_json::Value::String(crs))?
}
}
None => None,
};
Expand Down Expand Up @@ -173,6 +177,7 @@ mod test {
let good_crs_scalar = ScalarValue::Utf8(Some("EPSG:4326".to_string()));
let null_crs_scalar = ScalarValue::Utf8(None);
let epsg_code_scalar = ScalarValue::Int32(Some(4326));
let unset_scalar = ScalarValue::Int32(Some(0));
let questionable_crs_scalar = ScalarValue::Utf8(Some("gazornenplat".to_string()));

// Call with a string scalar destination
Expand All @@ -193,6 +198,12 @@ mod test {
assert_eq!(return_type, wkb_lnglat);
assert_value_equal(&result, &geom_lnglat);

// Call with an integer code of 0 (should unset the output crs)
let (return_type, result) =
call_udf(&udf, geom_lnglat.clone(), unset_scalar.clone()).unwrap();
assert_eq!(return_type, WKB_GEOMETRY);
assert_value_equal(&result, &geom_arg);

// Ensure that an engine can reject a CRS if the UDF was constructed with one
let udf_with_validation: ScalarUDF =
st_set_srid_with_engine_udf(Some(Arc::new(ExtremelyUnusefulEngine {}))).into();
Expand Down