Skip to content

Commit b38ad85

Browse files
feat(rust/sedona-functions): make ST_Translate accept deltaZ arg (#524)
1 parent 9b6e2e7 commit b38ad85

2 files changed

Lines changed: 283 additions & 22 deletions

File tree

python/sedonadb/tests/functions/test_transforms.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,71 @@ def test_st_translate(eng, geom, dx, dy, expected):
192192
f"SELECT ST_Translate({geom_or_null(geom)}, {val_or_null(dx)}, {val_or_null(dy)})",
193193
expected,
194194
)
195+
196+
197+
@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
198+
@pytest.mark.parametrize(
199+
("geom", "dx", "dy", "dz", "expected"),
200+
[
201+
# Nulls
202+
(None, None, None, None, None),
203+
(None, 1.0, 2.0, 3.0, None),
204+
("POINT Z (0 1 2)", None, 2.0, 3.0, None),
205+
("POINT Z (0 1 2)", 1.0, None, 3.0, None),
206+
("POINT Z (0 1 2)", 1.0, 2.0, None, None),
207+
("POINT Z (0 1 2)", 1.0, 2.0, 3.0, "POINT Z (1 3 5)"), # Positives
208+
("POINT Z (0 1 2)", -1.0, -2.0, -3.0, "POINT Z (-1 -1 -1)"), # Negatives
209+
("POINT Z (0 1 2)", 0.0, 0.0, 0.0, "POINT Z (0 1 2)"), # Zeroes
210+
("POINT Z (0 1 2)", 1, 2, 3, "POINT Z (1 3 5)"), # Integers
211+
("POINT (0 1)", 1.0, 2.0, 3.0, "POINT (1 3)"), # 2D
212+
("POINT M (0 1 2)", 1.0, 2.0, 3.0, "POINT M (1 3 2)"), # M
213+
("POINT ZM (0 1 2 3)", 1.0, 2.0, 3.0, "POINT ZM (1 3 5 3)"), # ZM
214+
# Not points
215+
("LINESTRING Z (0 1 2, 2 3 4)", 1.0, 2.0, 3.0, "LINESTRING Z (1 3 5, 3 5 7)"),
216+
(
217+
"POLYGON Z ((0 0 0, 1 0 2, 0 1 2, 0 0 0))",
218+
1.0,
219+
2.0,
220+
3.0,
221+
"POLYGON Z ((1 2 3, 2 2 5, 1 3 5, 1 2 3))",
222+
),
223+
("MULTIPOINT Z (0 1 2, 2 3 4)", 1.0, 2.0, 3.0, "MULTIPOINT Z (1 3 5, 3 5 7)"),
224+
(
225+
"MULTILINESTRING Z ((0 1 2, 2 3 4))",
226+
1.0,
227+
2.0,
228+
3.0,
229+
"MULTILINESTRING Z ((1 3 5, 3 5 7))",
230+
),
231+
(
232+
"MULTIPOLYGON Z (((0 0 0, 1 0 2, 0 1 2, 0 0 0)))",
233+
1.0,
234+
2.0,
235+
3.0,
236+
"MULTIPOLYGON Z (((1 2 3, 2 2 5, 1 3 5, 1 2 3)))",
237+
),
238+
(
239+
"GEOMETRYCOLLECTION Z (POINT Z (0 1 2))",
240+
1.0,
241+
2.0,
242+
3.0,
243+
"GEOMETRYCOLLECTION Z (POINT Z (1 3 5))",
244+
),
245+
# WKT output of geoarrow-c is causing this (both correctly output
246+
# empties)
247+
("POINT EMPTY", 1.0, 2.0, 3.0, "POINT (nan nan)"),
248+
("POINT Z EMPTY", 1.0, 2.0, 3.0, "POINT Z (nan nan nan)"),
249+
("LINESTRING EMPTY", 1.0, 2.0, 3.0, "LINESTRING EMPTY"),
250+
("POLYGON EMPTY", 1.0, 2.0, 3.0, "POLYGON EMPTY"),
251+
("MULTIPOINT EMPTY", 1.0, 2.0, 3.0, "MULTIPOINT EMPTY"),
252+
("MULTILINESTRING EMPTY", 1.0, 2.0, 3.0, "MULTILINESTRING EMPTY"),
253+
("MULTIPOLYGON EMPTY", 1.0, 2.0, 3.0, "MULTIPOLYGON EMPTY"),
254+
("GEOMETRYCOLLECTION EMPTY", 1.0, 2.0, 3.0, "GEOMETRYCOLLECTION EMPTY"),
255+
],
256+
)
257+
def test_st_translate_3d(eng, geom, dx, dy, dz, expected):
258+
eng = eng.create_or_skip()
259+
eng.assert_query_result(
260+
f"SELECT ST_Translate({geom_or_null(geom)}, {val_or_null(dx)}, {val_or_null(dy)}, {val_or_null(dz)})",
261+
expected,
262+
)

rust/sedona-functions/src/st_translate.rs

Lines changed: 215 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
// KIND, either express or implied. See the License for the
1515
// specific language governing permissions and limitations
1616
// under the License.
17-
use arrow_array::builder::BinaryBuilder;
17+
use arrow_array::{builder::BinaryBuilder, types::Float64Type, Array, PrimitiveArray};
1818
use arrow_schema::DataType;
1919
use datafusion_common::{cast::as_float64_array, error::Result, DataFusionError};
2020
use datafusion_expr::{
2121
scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility,
2222
};
2323

24+
use sedona_common::sedona_internal_err;
2425
use sedona_expr::{
2526
item_crs::ItemCrsKernel,
2627
scalar_udf::{SedonaScalarKernel, SedonaScalarUDF},
@@ -34,15 +35,18 @@ use sedona_schema::{
3435
datatypes::{SedonaType, WKB_GEOMETRY},
3536
matchers::ArgMatcher,
3637
};
37-
use std::{iter::zip, sync::Arc};
38+
use std::sync::Arc;
3839

3940
use crate::executor::WkbExecutor;
4041

4142
/// ST_Translate() scalar UDF
4243
pub fn st_translate_udf() -> SedonaScalarUDF {
4344
SedonaScalarUDF::new(
4445
"st_translate",
45-
ItemCrsKernel::wrap_impl(vec![Arc::new(STTranslate)]),
46+
ItemCrsKernel::wrap_impl(vec![
47+
Arc::new(STTranslate { is_3d: true }),
48+
Arc::new(STTranslate { is_3d: false }),
49+
]),
4650
Volatility::Immutable,
4751
Some(st_translate_doc()),
4852
)
@@ -62,18 +66,27 @@ fn st_translate_doc() -> Documentation {
6266
}
6367

6468
#[derive(Debug)]
65-
struct STTranslate;
69+
struct STTranslate {
70+
is_3d: bool,
71+
}
6672

6773
impl SedonaScalarKernel for STTranslate {
6874
fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
69-
let matcher = ArgMatcher::new(
75+
let matchers = if self.is_3d {
7076
vec![
7177
ArgMatcher::is_geometry(),
7278
ArgMatcher::is_numeric(),
7379
ArgMatcher::is_numeric(),
74-
],
75-
WKB_GEOMETRY,
76-
);
80+
ArgMatcher::is_numeric(),
81+
]
82+
} else {
83+
vec![
84+
ArgMatcher::is_geometry(),
85+
ArgMatcher::is_numeric(),
86+
ArgMatcher::is_numeric(),
87+
]
88+
};
89+
let matcher = ArgMatcher::new(matchers, WKB_GEOMETRY);
7790

7891
matcher.match_args(args)
7992
}
@@ -89,21 +102,40 @@ impl SedonaScalarKernel for STTranslate {
89102
WKB_MIN_PROBABLE_BYTES * executor.num_iterations(),
90103
);
91104

92-
let deltax = args[1]
93-
.cast_to(&DataType::Float64, None)?
94-
.to_array(executor.num_iterations())?;
95-
let deltay = args[2]
96-
.cast_to(&DataType::Float64, None)?
97-
.to_array(executor.num_iterations())?;
98-
let deltax_array = as_float64_array(&deltax)?;
99-
let deltay_array = as_float64_array(&deltay)?;
100-
let mut delta_iter = zip(deltax_array, deltay_array);
105+
let array_args = args[1..]
106+
.iter()
107+
.map(|arg| {
108+
arg.cast_to(&DataType::Float64, None)?
109+
.to_array(executor.num_iterations())
110+
})
111+
.collect::<Result<Vec<Arc<dyn arrow_array::Array>>>>()?;
112+
113+
let deltax_array = as_float64_array(&array_args[0])?;
114+
let deltay_array = as_float64_array(&array_args[1])?;
115+
116+
let mut deltas = if self.is_3d {
117+
if args.len() != 4 {
118+
return sedona_internal_err!("Invalid number of arguments are passed");
119+
}
120+
121+
let deltaz_array = as_float64_array(&array_args[2])?;
122+
Deltas::new(deltax_array, deltay_array, Some(deltaz_array))
123+
} else {
124+
if args.len() != 3 {
125+
return sedona_internal_err!("Invalid number of arguments are passed");
126+
}
127+
128+
Deltas::new(deltax_array, deltay_array, None)
129+
};
101130

102131
executor.execute_wkb_void(|maybe_wkb| {
103-
let (deltax, deltay) = delta_iter.next().unwrap();
104-
match (maybe_wkb, deltax, deltay) {
105-
(Some(wkb), Some(deltax), Some(deltay)) => {
106-
let trans = Translate { deltax, deltay };
132+
match (maybe_wkb, deltas.next().unwrap()) {
133+
(Some(wkb), Some((deltax, deltay, deltaz))) => {
134+
let trans = Translate {
135+
deltax,
136+
deltay,
137+
deltaz,
138+
};
107139
transform(wkb, &trans, &mut builder)
108140
.map_err(|e| DataFusionError::External(Box::new(e)))?;
109141
builder.append_value([]);
@@ -120,10 +152,76 @@ impl SedonaScalarKernel for STTranslate {
120152
}
121153
}
122154

155+
#[derive(Debug)]
156+
struct Deltas<'a> {
157+
index: usize,
158+
x: &'a PrimitiveArray<Float64Type>,
159+
y: &'a PrimitiveArray<Float64Type>,
160+
z: Option<&'a PrimitiveArray<Float64Type>>,
161+
no_null: bool,
162+
}
163+
164+
impl<'a> Deltas<'a> {
165+
fn new(
166+
x: &'a PrimitiveArray<Float64Type>,
167+
y: &'a PrimitiveArray<Float64Type>,
168+
z: Option<&'a PrimitiveArray<Float64Type>>,
169+
) -> Self {
170+
let no_null = x.null_count() == 0
171+
&& y.null_count() == 0
172+
&& match z {
173+
Some(z) => z.null_count() == 0,
174+
None => true,
175+
};
176+
177+
Self {
178+
index: 0,
179+
x,
180+
y,
181+
z,
182+
no_null,
183+
}
184+
}
185+
fn is_null(&self, i: usize) -> bool {
186+
if self.no_null {
187+
return false;
188+
}
189+
190+
self.x.is_null(i)
191+
|| self.y.is_null(i)
192+
|| match self.z {
193+
Some(z) => z.is_null(i),
194+
None => false,
195+
}
196+
}
197+
}
198+
199+
impl<'a> Iterator for Deltas<'a> {
200+
type Item = Option<(f64, f64, f64)>;
201+
202+
fn next(&mut self) -> Option<Self::Item> {
203+
let i = self.index;
204+
self.index += 1;
205+
206+
if self.is_null(i) {
207+
return Some(None);
208+
}
209+
210+
let x = self.x.value(i);
211+
let y = self.y.value(i);
212+
let z = match self.z {
213+
Some(z) => z.value(i),
214+
None => 0.0,
215+
};
216+
Some(Some((x, y, z)))
217+
}
218+
}
219+
123220
#[derive(Debug)]
124221
struct Translate {
125222
deltax: f64,
126223
deltay: f64,
224+
deltaz: f64,
127225
}
128226

129227
impl CrsTransform for Translate {
@@ -132,6 +230,16 @@ impl CrsTransform for Translate {
132230
coord.1 += self.deltay;
133231
Ok(())
134232
}
233+
234+
fn transform_coord_3d(
235+
&self,
236+
coord: &mut (f64, f64, f64),
237+
) -> std::result::Result<(), SedonaGeometryError> {
238+
coord.0 += self.deltax;
239+
coord.1 += self.deltay;
240+
coord.2 += self.deltaz;
241+
Ok(())
242+
}
135243
}
136244

137245
#[cfg(test)]
@@ -155,7 +263,7 @@ mod tests {
155263
}
156264

157265
#[rstest]
158-
fn udf(#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] sedona_type: SedonaType) {
266+
fn udf_2d(#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] sedona_type: SedonaType) {
159267
let tester = ScalarUdfTester::new(
160268
st_translate_udf().into(),
161269
vec![
@@ -225,6 +333,91 @@ mod tests {
225333
assert_array_equal(&result, &expected);
226334
}
227335

336+
#[rstest]
337+
fn udf_3d(#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] sedona_type: SedonaType) {
338+
let tester = ScalarUdfTester::new(
339+
st_translate_udf().into(),
340+
vec![
341+
sedona_type.clone(),
342+
SedonaType::Arrow(DataType::Float64),
343+
SedonaType::Arrow(DataType::Float64),
344+
SedonaType::Arrow(DataType::Float64),
345+
],
346+
);
347+
tester.assert_return_type(WKB_GEOMETRY);
348+
349+
let points = create_array(
350+
&[
351+
None,
352+
Some("POINT EMPTY"),
353+
Some("POINT EMPTY"),
354+
Some("POINT EMPTY"),
355+
Some("POINT EMPTY"),
356+
Some("POINT Z EMPTY"),
357+
Some("POINT (0 1)"),
358+
Some("POINT Z (4 5 6)"),
359+
],
360+
&sedona_type,
361+
);
362+
363+
let dx = create_array!(
364+
Float64,
365+
[
366+
Some(1.0),
367+
None,
368+
Some(1.0),
369+
Some(1.0),
370+
Some(1.0),
371+
Some(1.0),
372+
Some(1.0),
373+
Some(1.0)
374+
]
375+
);
376+
let dy = create_array!(
377+
Float64,
378+
[
379+
Some(2.0),
380+
Some(2.0),
381+
None,
382+
Some(2.0),
383+
Some(2.0),
384+
Some(2.0),
385+
Some(2.0),
386+
Some(2.0)
387+
]
388+
);
389+
let dz = create_array!(
390+
Float64,
391+
[
392+
Some(3.0),
393+
Some(3.0),
394+
Some(3.0),
395+
None,
396+
Some(3.0),
397+
Some(3.0),
398+
Some(3.0),
399+
Some(3.0)
400+
]
401+
);
402+
403+
let expected = create_array(
404+
&[
405+
None,
406+
None,
407+
None,
408+
None,
409+
Some("POINT EMPTY"),
410+
Some("POINT Z EMPTY"),
411+
Some("POINT (1 3)"),
412+
Some("POINT Z (5 7 9)"),
413+
],
414+
&WKB_GEOMETRY,
415+
);
416+
417+
let result = tester.invoke_arrays(vec![points, dx, dy, dz]).unwrap();
418+
assert_array_equal(&result, &expected);
419+
}
420+
228421
#[rstest]
229422
fn udf_invoke_item_crs(#[values(WKB_GEOMETRY_ITEM_CRS.clone())] sedona_type: SedonaType) {
230423
let tester = ScalarUdfTester::new(

0 commit comments

Comments
 (0)