|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | +use std::{sync::Arc, vec}; |
| 18 | + |
| 19 | +use crate::executor::RasterExecutor; |
| 20 | +use arrow_array::builder::StringBuilder; |
| 21 | +use arrow_array::cast::AsArray; |
| 22 | +use arrow_array::Array; |
| 23 | +use arrow_schema::DataType; |
| 24 | +use datafusion_common::error::Result; |
| 25 | +use datafusion_common::DataFusionError; |
| 26 | +use datafusion_expr::{ |
| 27 | + scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility, |
| 28 | +}; |
| 29 | +use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF}; |
| 30 | +use sedona_raster::traits::RasterRef; |
| 31 | +use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher}; |
| 32 | + |
| 33 | +/// RS_GeoReference() scalar UDF implementation |
| 34 | +/// |
| 35 | +/// Returns the georeference metadata of raster as a string in GDAL or ESRI format |
| 36 | +pub fn rs_georeference_udf() -> SedonaScalarUDF { |
| 37 | + SedonaScalarUDF::new( |
| 38 | + "rs_georeference", |
| 39 | + vec![ |
| 40 | + Arc::new(RsGeoReferenceOneArg {}), |
| 41 | + Arc::new(RsGeoReferenceTwoArg {}), |
| 42 | + ], |
| 43 | + Volatility::Immutable, |
| 44 | + Some(rs_georeference_doc()), |
| 45 | + ) |
| 46 | +} |
| 47 | + |
| 48 | +fn rs_georeference_doc() -> Documentation { |
| 49 | + Documentation::builder( |
| 50 | + DOC_SECTION_OTHER, |
| 51 | + "Returns the georeference metadata of raster as a string in GDAL or ESRI format as commonly seen in a world file. Default is GDAL if not specified. Both formats output six lines: scalex, skewy, skewx, scaley, upperleftx, upperlefty. In GDAL format the upper-left coordinates refer to the corner of the upper-left pixel, while in ESRI format they are shifted to the center of the upper-left pixel.".to_string(), |
| 52 | + "RS_GeoReference(raster: Raster, format: String = 'GDAL')".to_string(), |
| 53 | + ) |
| 54 | + .with_argument("raster", "Raster: Input raster") |
| 55 | + .with_argument("format", "String: Output format, either 'GDAL' (default) or 'ESRI'. GDAL reports the upper-left corner of the upper-left pixel; ESRI shifts the coordinates to the center of the upper-left pixel.") |
| 56 | + .with_sql_example("SELECT RS_GeoReference(RS_Example())".to_string()) |
| 57 | + .build() |
| 58 | +} |
| 59 | + |
| 60 | +/// Format type for GeoReference output as commonly seen in a |
| 61 | +/// [world file](https://en.wikipedia.org/wiki/World_file). |
| 62 | +/// |
| 63 | +/// Both formats output six lines: scalex, skewy, skewx, scaley, upperleftx, upperlefty. |
| 64 | +/// The difference is how the upper-left coordinate is reported: |
| 65 | +#[derive(Debug, Clone, Copy, PartialEq, Eq)] |
| 66 | +enum GeoReferenceFormat { |
| 67 | + /// GDAL format: upperleftx and upperlefty are the coordinates of the upper-left corner |
| 68 | + /// of the upper-left pixel. |
| 69 | + Gdal, |
| 70 | + /// ESRI format: upperleftx and upperlefty are shifted to the center of the upper-left |
| 71 | + /// pixel, i.e. `upperleftx + scalex * 0.5` and `upperlefty + scaley * 0.5`. |
| 72 | + Esri, |
| 73 | +} |
| 74 | + |
| 75 | +impl GeoReferenceFormat { |
| 76 | + fn from_str(s: &str) -> Result<Self> { |
| 77 | + match s.to_uppercase().as_str() { |
| 78 | + "GDAL" => Ok(GeoReferenceFormat::Gdal), |
| 79 | + "ESRI" => Ok(GeoReferenceFormat::Esri), |
| 80 | + _ => Err(DataFusionError::Execution(format!( |
| 81 | + "Invalid GeoReference format '{}'. Supported formats are 'GDAL' and 'ESRI'.", |
| 82 | + s |
| 83 | + ))), |
| 84 | + } |
| 85 | + } |
| 86 | +} |
| 87 | + |
| 88 | +/// Estimated bytes per georeference string for StringBuilder preallocation. |
| 89 | +/// Output is 6 lines of `{:.10}` formatted f64 values separated by newlines. |
| 90 | +/// Each value is at most ~20 bytes (e.g. "-12345678.1234567890"), giving |
| 91 | +/// 6 * 20 + 5 newlines = 125 bytes. |
| 92 | +const PREALLOC_BYTES_PER_GEOREF: usize = 125; |
| 93 | + |
| 94 | +/// One-argument kernel: RS_GeoReference(raster) - uses GDAL format by default |
| 95 | +#[derive(Debug)] |
| 96 | +struct RsGeoReferenceOneArg {} |
| 97 | + |
| 98 | +impl SedonaScalarKernel for RsGeoReferenceOneArg { |
| 99 | + fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> { |
| 100 | + let matcher = ArgMatcher::new( |
| 101 | + vec![ArgMatcher::is_raster()], |
| 102 | + SedonaType::Arrow(DataType::Utf8), |
| 103 | + ); |
| 104 | + matcher.match_args(args) |
| 105 | + } |
| 106 | + |
| 107 | + fn invoke_batch( |
| 108 | + &self, |
| 109 | + arg_types: &[SedonaType], |
| 110 | + args: &[ColumnarValue], |
| 111 | + ) -> Result<ColumnarValue> { |
| 112 | + let executor = RasterExecutor::new(arg_types, args); |
| 113 | + |
| 114 | + let preallocate_bytes = PREALLOC_BYTES_PER_GEOREF * executor.num_iterations(); |
| 115 | + let mut builder = |
| 116 | + StringBuilder::with_capacity(executor.num_iterations(), preallocate_bytes); |
| 117 | + |
| 118 | + executor.execute_raster_void(|_i, raster_opt| { |
| 119 | + format_georeference(raster_opt, GeoReferenceFormat::Gdal, &mut builder) |
| 120 | + })?; |
| 121 | + |
| 122 | + executor.finish(Arc::new(builder.finish())) |
| 123 | + } |
| 124 | +} |
| 125 | + |
| 126 | +/// Two-argument kernel: RS_GeoReference(raster, format) |
| 127 | +#[derive(Debug)] |
| 128 | +struct RsGeoReferenceTwoArg {} |
| 129 | + |
| 130 | +impl SedonaScalarKernel for RsGeoReferenceTwoArg { |
| 131 | + fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> { |
| 132 | + let matcher = ArgMatcher::new( |
| 133 | + vec![ArgMatcher::is_raster(), ArgMatcher::is_string()], |
| 134 | + SedonaType::Arrow(DataType::Utf8), |
| 135 | + ); |
| 136 | + matcher.match_args(args) |
| 137 | + } |
| 138 | + |
| 139 | + fn invoke_batch( |
| 140 | + &self, |
| 141 | + arg_types: &[SedonaType], |
| 142 | + args: &[ColumnarValue], |
| 143 | + ) -> Result<ColumnarValue> { |
| 144 | + let executor = RasterExecutor::new(arg_types, args); |
| 145 | + |
| 146 | + // Expand the format parameter to an array |
| 147 | + let format_array = args[1].clone().into_array(executor.num_iterations())?; |
| 148 | + let format_array = format_array.as_string::<i32>(); |
| 149 | + |
| 150 | + let preallocate_bytes = PREALLOC_BYTES_PER_GEOREF * executor.num_iterations(); |
| 151 | + let mut builder = |
| 152 | + StringBuilder::with_capacity(executor.num_iterations(), preallocate_bytes); |
| 153 | + |
| 154 | + executor.execute_raster_void(|i, raster_opt| { |
| 155 | + if format_array.is_null(i) { |
| 156 | + builder.append_null(); |
| 157 | + return Ok(()); |
| 158 | + } |
| 159 | + let format = GeoReferenceFormat::from_str(format_array.value(i))?; |
| 160 | + format_georeference(raster_opt, format, &mut builder) |
| 161 | + })?; |
| 162 | + |
| 163 | + executor.finish(Arc::new(builder.finish())) |
| 164 | + } |
| 165 | +} |
| 166 | + |
| 167 | +/// Format the georeference metadata for a raster |
| 168 | +fn format_georeference( |
| 169 | + raster_opt: Option<&sedona_raster::array::RasterRefImpl<'_>>, |
| 170 | + format: GeoReferenceFormat, |
| 171 | + builder: &mut StringBuilder, |
| 172 | +) -> Result<()> { |
| 173 | + match raster_opt { |
| 174 | + None => builder.append_null(), |
| 175 | + Some(raster) => { |
| 176 | + let metadata = raster.metadata(); |
| 177 | + let scale_x = metadata.scale_x(); |
| 178 | + let scale_y = metadata.scale_y(); |
| 179 | + let skew_x = metadata.skew_x(); |
| 180 | + let skew_y = metadata.skew_y(); |
| 181 | + let upper_left_x = metadata.upper_left_x(); |
| 182 | + let upper_left_y = metadata.upper_left_y(); |
| 183 | + |
| 184 | + let georeference = match format { |
| 185 | + GeoReferenceFormat::Gdal => { |
| 186 | + format!( |
| 187 | + "{:.10}\n{:.10}\n{:.10}\n{:.10}\n{:.10}\n{:.10}", |
| 188 | + scale_x, skew_y, skew_x, scale_y, upper_left_x, upper_left_y |
| 189 | + ) |
| 190 | + } |
| 191 | + GeoReferenceFormat::Esri => { |
| 192 | + let esri_upper_left_x = upper_left_x + scale_x * 0.5; |
| 193 | + let esri_upper_left_y = upper_left_y + scale_y * 0.5; |
| 194 | + format!( |
| 195 | + "{:.10}\n{:.10}\n{:.10}\n{:.10}\n{:.10}\n{:.10}", |
| 196 | + scale_x, skew_y, skew_x, scale_y, esri_upper_left_x, esri_upper_left_y |
| 197 | + ) |
| 198 | + } |
| 199 | + }; |
| 200 | + |
| 201 | + builder.append_value(georeference); |
| 202 | + } |
| 203 | + } |
| 204 | + Ok(()) |
| 205 | +} |
| 206 | + |
| 207 | +#[cfg(test)] |
| 208 | +mod tests { |
| 209 | + use super::*; |
| 210 | + use arrow_array::{Array, StringArray}; |
| 211 | + use datafusion_common::ScalarValue; |
| 212 | + use datafusion_expr::ScalarUDF; |
| 213 | + use sedona_schema::datatypes::RASTER; |
| 214 | + use sedona_testing::compare::assert_array_equal; |
| 215 | + use sedona_testing::rasters::generate_test_rasters; |
| 216 | + use sedona_testing::testers::ScalarUdfTester; |
| 217 | + |
| 218 | + #[test] |
| 219 | + fn udf_metadata() { |
| 220 | + let udf: ScalarUDF = rs_georeference_udf().into(); |
| 221 | + assert_eq!(udf.name(), "rs_georeference"); |
| 222 | + assert!(udf.documentation().is_some()); |
| 223 | + } |
| 224 | + |
| 225 | + #[test] |
| 226 | + fn udf_georeference_gdal_default() { |
| 227 | + let udf: ScalarUDF = rs_georeference_udf().into(); |
| 228 | + let tester = ScalarUdfTester::new(udf, vec![RASTER]); |
| 229 | + |
| 230 | + tester.assert_return_type(DataType::Utf8); |
| 231 | + |
| 232 | + // Test with rasters (one-arg, default GDAL) |
| 233 | + let rasters = generate_test_rasters(3, Some(1)).unwrap(); |
| 234 | + let result = tester.invoke_array(Arc::new(rasters.clone())).unwrap(); |
| 235 | + |
| 236 | + let expected: Arc<dyn Array> = Arc::new(StringArray::from(vec![ |
| 237 | + Some("0.1000000000\n0.0000000000\n0.0000000000\n-0.2000000000\n1.0000000000\n2.0000000000"), |
| 238 | + None, |
| 239 | + Some("0.2000000000\n0.0800000000\n0.0600000000\n-0.4000000000\n3.0000000000\n4.0000000000"), |
| 240 | + ])); |
| 241 | + assert_array_equal(&result, &expected); |
| 242 | + |
| 243 | + // Test with explicit "GDAL" or "gdal" (two-arg) |
| 244 | + for format in ["GDAL", "gdal"] { |
| 245 | + let udf: ScalarUDF = rs_georeference_udf().into(); |
| 246 | + let tester = ScalarUdfTester::new(udf, vec![RASTER, SedonaType::Arrow(DataType::Utf8)]); |
| 247 | + let result = tester |
| 248 | + .invoke_array_scalar(Arc::new(rasters.clone()), format) |
| 249 | + .unwrap(); |
| 250 | + assert_array_equal(&result, &expected); |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + #[test] |
| 255 | + fn udf_georeference_esri() { |
| 256 | + let udf: ScalarUDF = rs_georeference_udf().into(); |
| 257 | + let tester = ScalarUdfTester::new(udf, vec![RASTER, SedonaType::Arrow(DataType::Utf8)]); |
| 258 | + |
| 259 | + let expected: Arc<dyn Array> = Arc::new(StringArray::from(vec![ |
| 260 | + Some("0.1000000000\n0.0000000000\n0.0000000000\n-0.2000000000\n1.0500000000\n1.9000000000"), |
| 261 | + None, |
| 262 | + Some("0.2000000000\n0.0800000000\n0.0600000000\n-0.4000000000\n3.1000000000\n3.8000000000"), |
| 263 | + ])); |
| 264 | + |
| 265 | + for format in ["ESRI", "esri"] { |
| 266 | + let rasters = generate_test_rasters(3, Some(1)).unwrap(); |
| 267 | + let result = tester |
| 268 | + .invoke_array_scalar(Arc::new(rasters), format) |
| 269 | + .unwrap(); |
| 270 | + assert_array_equal(&result, &expected); |
| 271 | + } |
| 272 | + } |
| 273 | + |
| 274 | + #[test] |
| 275 | + fn udf_georeference_null_scalar() { |
| 276 | + let udf: ScalarUDF = rs_georeference_udf().into(); |
| 277 | + let tester = ScalarUdfTester::new(udf, vec![RASTER]); |
| 278 | + |
| 279 | + // Test with null scalar |
| 280 | + let result = tester.invoke_scalar(ScalarValue::Null).unwrap(); |
| 281 | + tester.assert_scalar_result_equals(result, ScalarValue::Utf8(None)); |
| 282 | + } |
| 283 | + |
| 284 | + #[test] |
| 285 | + fn udf_georeference_with_array_format() { |
| 286 | + let udf: ScalarUDF = rs_georeference_udf().into(); |
| 287 | + let tester = ScalarUdfTester::new(udf, vec![RASTER, SedonaType::Arrow(DataType::Utf8)]); |
| 288 | + |
| 289 | + let rasters = generate_test_rasters(4, Some(1)).unwrap(); |
| 290 | + let formats = Arc::new(StringArray::from(vec![ |
| 291 | + Some("GDAL"), // explicit GDAL |
| 292 | + Some("ESRI"), // won't matter since raster 1 is null |
| 293 | + None, // null format -> NULL output |
| 294 | + Some("ESRI"), // explicit ESRI |
| 295 | + ])); |
| 296 | + |
| 297 | + let result = tester |
| 298 | + .invoke_arrays(vec![Arc::new(rasters), formats]) |
| 299 | + .unwrap(); |
| 300 | + let expected: Arc<dyn Array> = Arc::new(StringArray::from(vec![ |
| 301 | + // explicit GDAL |
| 302 | + Some("0.1000000000\n0.0000000000\n0.0000000000\n-0.2000000000\n1.0000000000\n2.0000000000"), |
| 303 | + // null raster |
| 304 | + None, |
| 305 | + // null format -> NULL output |
| 306 | + None, |
| 307 | + // explicit ESRI |
| 308 | + Some("0.3000000000\n0.1200000000\n0.0900000000\n-0.6000000000\n4.1500000000\n4.7000000000"), |
| 309 | + ])); |
| 310 | + assert_array_equal(&result, &expected); |
| 311 | + } |
| 312 | + |
| 313 | + #[test] |
| 314 | + fn udf_georeference_invalid_format() { |
| 315 | + let udf: ScalarUDF = rs_georeference_udf().into(); |
| 316 | + let tester = ScalarUdfTester::new(udf, vec![RASTER, SedonaType::Arrow(DataType::Utf8)]); |
| 317 | + |
| 318 | + let rasters = generate_test_rasters(3, Some(1)).unwrap(); |
| 319 | + let result = tester.invoke_array_scalar(Arc::new(rasters), "INVALID"); |
| 320 | + |
| 321 | + assert!(result.is_err()); |
| 322 | + let err_msg = result.unwrap_err().to_string(); |
| 323 | + assert!( |
| 324 | + err_msg.contains("Invalid GeoReference format"), |
| 325 | + "Expected error about invalid format, got: {}", |
| 326 | + err_msg |
| 327 | + ); |
| 328 | + } |
| 329 | +} |
0 commit comments