diff --git a/rust/sedona-raster-functions/benches/native-raster-functions.rs b/rust/sedona-raster-functions/benches/native-raster-functions.rs index cb4cb7305..8ed80a2e6 100644 --- a/rust/sedona-raster-functions/benches/native-raster-functions.rs +++ b/rust/sedona-raster-functions/benches/native-raster-functions.rs @@ -23,6 +23,7 @@ fn criterion_benchmark(c: &mut Criterion) { benchmark::scalar(c, &f, "native-raster", "rs_crs", Raster(64, 64)); benchmark::scalar(c, &f, "native-raster", "rs_envelope", Raster(64, 64)); benchmark::scalar(c, &f, "native-raster", "rs_height", Raster(64, 64)); + benchmark::scalar(c, &f, "native-raster", "rs_numbands", Raster(64, 64)); benchmark::scalar( c, &f, diff --git a/rust/sedona-raster-functions/src/lib.rs b/rust/sedona-raster-functions/src/lib.rs index 2d231384f..e36109724 100644 --- a/rust/sedona-raster-functions/src/lib.rs +++ b/rust/sedona-raster-functions/src/lib.rs @@ -20,6 +20,7 @@ pub mod register; pub mod rs_envelope; pub mod rs_example; pub mod rs_geotransform; +pub mod rs_numbands; pub mod rs_rastercoordinate; pub mod rs_size; pub mod rs_srid; diff --git a/rust/sedona-raster-functions/src/register.rs b/rust/sedona-raster-functions/src/register.rs index 8685b62ed..d61864eac 100644 --- a/rust/sedona-raster-functions/src/register.rs +++ b/rust/sedona-raster-functions/src/register.rs @@ -47,6 +47,7 @@ pub fn default_function_set() -> FunctionSet { crate::rs_geotransform::rs_skewy_udf, crate::rs_geotransform::rs_upperleftx_udf, crate::rs_geotransform::rs_upperlefty_udf, + crate::rs_numbands::rs_numbands_udf, crate::rs_rastercoordinate::rs_worldtorastercoord_udf, crate::rs_rastercoordinate::rs_worldtorastercoordx_udf, crate::rs_rastercoordinate::rs_worldtorastercoordy_udf, diff --git a/rust/sedona-raster-functions/src/rs_numbands.rs b/rust/sedona-raster-functions/src/rs_numbands.rs new file mode 100644 index 000000000..e7ea0e3e5 --- /dev/null +++ b/rust/sedona-raster-functions/src/rs_numbands.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +use std::{sync::Arc, vec}; + +use crate::executor::RasterExecutor; +use arrow_array::builder::UInt32Builder; +use arrow_schema::DataType; +use datafusion_common::error::Result; +use datafusion_expr::{ + scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility, +}; +use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF}; +use sedona_raster::traits::RasterRef; +use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher}; + +/// RS_NumBands() scalar UDF implementation +/// +/// Returns the number of bands in the raster +pub fn rs_numbands_udf() -> SedonaScalarUDF { + SedonaScalarUDF::new( + "rs_numbands", + vec![Arc::new(RsNumBands {})], + Volatility::Immutable, + Some(rs_numbands_doc()), + ) +} + +fn rs_numbands_doc() -> Documentation { + Documentation::builder( + DOC_SECTION_OTHER, + "Returns the number of bands in the raster.".to_string(), + "RS_NumBands(raster: Raster)".to_string(), + ) + .with_argument("raster", "Raster: Input raster") + .with_sql_example("SELECT RS_NumBands(RS_Example())".to_string()) + .build() +} + +#[derive(Debug)] +struct RsNumBands {} + +impl SedonaScalarKernel for RsNumBands { + fn return_type(&self, args: &[SedonaType]) -> Result> { + let matcher = ArgMatcher::new( + vec![ArgMatcher::is_raster()], + SedonaType::Arrow(DataType::UInt32), + ); + + matcher.match_args(args) + } + + fn invoke_batch( + &self, + arg_types: &[SedonaType], + args: &[ColumnarValue], + ) -> Result { + let executor = RasterExecutor::new(arg_types, args); + let mut builder = UInt32Builder::with_capacity(executor.num_iterations()); + + executor.execute_raster_void(|_i, raster_opt| { + match raster_opt { + None => builder.append_null(), + Some(raster) => { + let num_bands = raster.bands().len() as u32; + builder.append_value(num_bands); + } + } + Ok(()) + })?; + + executor.finish(Arc::new(builder.finish())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::UInt32Array; + use datafusion_common::ScalarValue; + use datafusion_expr::ScalarUDF; + use sedona_schema::datatypes::RASTER; + use sedona_testing::compare::assert_array_equal; + use sedona_testing::rasters::generate_test_rasters; + use sedona_testing::testers::ScalarUdfTester; + + #[test] + fn udf_metadata() { + let udf: ScalarUDF = rs_numbands_udf().into(); + assert_eq!(udf.name(), "rs_numbands"); + assert!(udf.documentation().is_some()); + } + + #[test] + fn udf_numbands() { + let udf: ScalarUDF = rs_numbands_udf().into(); + let tester = ScalarUdfTester::new(udf, vec![RASTER]); + + tester.assert_return_type(DataType::UInt32); + + // Test with rasters - generate_test_rasters creates rasters with 1 band each + let rasters = generate_test_rasters(3, Some(1)).unwrap(); + let expected: Arc = + Arc::new(UInt32Array::from(vec![Some(1), None, Some(1)])); + + let result = tester.invoke_array(Arc::new(rasters)).unwrap(); + assert_array_equal(&result, &expected); + + // Test with null scalar + let result = tester.invoke_scalar(ScalarValue::Null).unwrap(); + tester.assert_scalar_result_equals(result, ScalarValue::UInt32(None)); + } +}