Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions python/sedonadb/python/sedonadb/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import sys
from functools import cached_property
from pathlib import Path
from typing import Any, Dict, Iterable, Literal, Optional, Union
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union

from sedonadb._lib import InternalContext, configure_proj_shared
from sedonadb._options import Options
Expand Down Expand Up @@ -273,23 +273,62 @@ def read_pyogrio(
self.options,
)

def sql(
    self, sql: str, *, params: Union[List, Tuple, Dict, None] = None
) -> "DataFrame":
    """Create a [DataFrame][sedonadb.dataframe.DataFrame] by executing SQL

    Parses a SQL string into a logical plan and returns a DataFrame
    that can be used to request results or further modify the query.

    Args:
        sql: A single SQL statement.
        params: An optional specification of parameters to bind if sql
            contains placeholders (e.g., `$1` or `$my_param`). Use a
            list or tuple to replace positional parameters or a dictionary
            to replace named parameters. This is shorthand for
            `.sql(...).with_params(...)` that is syntax-compatible with
            DuckDB. See `lit()` for a list of supported Python objects.

    Returns:
        The DataFrame for the statement, with `params` bound via
        `with_params()` when `params` is not None.

    Raises:
        ValueError: If `params` is not None, a list, a tuple, or a dict.

    Examples:

        >>> sd = sedona.db.connect()
        >>> sd.sql("SELECT ST_Point(0, 1) as geom")
        <sedonadb.dataframe.DataFrame object at ...>
        >>> sd.sql("SELECT ST_Point(0, 1) AS geom").show()
        ┌────────────┐
        │ geom │
        │ geometry │
        ╞════════════╡
        │ POINT(0 1) │
        └────────────┘
        >>> sd.sql("SELECT ST_Point($1, $2) AS geom", params=(0, 1)).show()
        ┌────────────┐
        │ geom │
        │ geometry │
        ╞════════════╡
        │ POINT(0 1) │
        └────────────┘
        >>> sd.sql("SELECT ST_Point($x, $y) AS geom", params={"x": 0, "y": 1}).show()
        ┌────────────┐
        │ geom │
        │ geometry │
        ╞════════════╡
        │ POINT(0 1) │
        └────────────┘

    """
    # Reject unsupported containers before handing the statement to the
    # backend: the caller's mistake is in the arguments, not in the SQL.
    if params is not None and not isinstance(params, (tuple, list, dict)):
        raise ValueError("params must be a list, tuple, or dict of scalar values")

    df = DataFrame(self._impl, self._impl.sql(sql), self.options)

    if params is None:
        return df

    # A dict binds named placeholders; a list/tuple binds positional ones.
    if isinstance(params, dict):
        return df.with_params(**params)

    return df.with_params(*params)

def register_udf(self, udf: Any):
"""Register a user-defined function
Expand Down
50 changes: 46 additions & 4 deletions python/sedonadb/python/sedonadb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
# under the License.

from pathlib import Path
from typing import TYPE_CHECKING, Union, Optional, Any, Iterable, Literal
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union

from sedonadb.utility import sedona # noqa: F401


if TYPE_CHECKING:
import pandas
import geopandas
import pandas
import pyarrow


Expand Down Expand Up @@ -156,6 +155,49 @@ def count(self) -> int:
"""
return self._impl.count()

def with_params(self, *args: Any, **kwargs: Any) -> "DataFrame":
    """Replace unbound parameters in this query

    For DataFrames that represent a logical plan that contains parameters (e.g.,
    a SQL query of `SELECT $1 + 2`), replace parameters with concrete values.
    See `lit()` for a list of supported Python objects.

    Args:
        args: Values to bind to positional parameters (e.g., `$1`, `$2`, `$3`)
        kwargs: Values to bind to named parameters (e.g., `$my_param`). Note that
            positional and named parameters cannot currently be mixed (i.e.,
            parameters must be all positional or all named).

    Returns:
        A new DataFrame wrapping the result of binding the parameters; this
        DataFrame is not modified.

    Examples:

        >>> sd = sedona.db.connect()
        >>> sd.sql("SELECT $1 + 2 AS c").with_params(100).show()
        ┌───────┐
        │ c │
        │ int64 │
        ╞═══════╡
        │ 102 │
        └───────┘
        >>> sd.sql("SELECT $my_param + 2 AS c").with_params(my_param=100).show()
        ┌───────┐
        │ c │
        │ int64 │
        ╞═══════╡
        │ 102 │
        └───────┘

    """
    # Imported here rather than at module level, as in the original block.
    from sedonadb.expr.literal import lit

    # Every value is wrapped with lit() before being handed to the
    # implementation; lit() passes existing Literal objects through unchanged.
    positional_params = [lit(value) for value in args]
    named_params = {name: lit(value) for name, value in kwargs.items()}

    return DataFrame(
        self._ctx,
        self._impl.with_params(positional_params, named_params),
        self._options,
    )

def __arrow_c_schema__(self):
"""ArrowSchema PyCapsule interface

Expand Down Expand Up @@ -250,8 +292,8 @@ def to_arrow_table(self, schema: Any = None) -> "pyarrow.Table":
geometry: [[01010000000000000000000000000000000000F03F]]

"""
import pyarrow as pa
import geoarrow.pyarrow # noqa: F401
import pyarrow as pa

# Collects all batches into an object that exposes __arrow_c_stream__()
batches = self._impl.to_batches(schema)
Expand Down
16 changes: 16 additions & 0 deletions python/sedonadb/python/sedonadb/expr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
180 changes: 180 additions & 0 deletions python/sedonadb/python/sedonadb/expr/literal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any


class Literal:
    """A literal (constant) expression

    Represents a value in a query that does not change based on other
    information in the query or the environment; such an expression is
    also called a constant. Instances are normally created with the
    `lit()` function, or automatically when an arbitrary Python object
    is passed to a context (e.g., parameterized SQL queries) where a
    literal is required.

    Resolution is lazy: the wrapped Python object is stored as-is and is
    only converted to Arrow when `__arrow_c_array__()` is called, so that
    specific contexts still have access to the underlying object and can
    resolve it specially (e.g., by forcing a specific Arrow type) if
    required.

    Args:
        value: An arbitrary Python object.
    """

    def __init__(self, value: Any):
        # Kept unconverted; see __arrow_c_array__ for the deferred conversion.
        self._value = value

    def __arrow_c_array__(self, requested_schema=None):
        """Resolve the wrapped value and export it via the Arrow PyCapsule interface"""
        arrow_obj = _resolve_arrow_lit(self._value)
        return arrow_obj.__arrow_c_array__(requested_schema=requested_schema)

    def __repr__(self):
        return f"<Literal>\n{self._value!r}"


def lit(value: Any) -> Literal:
    """Create a literal (constant) expression

    Wraps value in a `Literal`, or returns value unchanged if it already
    is a `Literal`. This is the primary function that should be used to
    wrap an arbitrary Python object as a constant to prepare it as input
    to any SedonaDB logical expression context (e.g., parameterized SQL).

    Literal values can be created from a variety of Python objects whose
    representation as a scalar constant is unambiguous. Any object that
    is accepted by `pyarrow.array([...])` is supported in addition to:

    - Shapely geometries become SedonaDB geometry objects.
    - GeoSeries objects of length 1 become SedonaDB geometries
      with CRS preserved.
    - GeoDataFrame objects with a single column and single row become
      SedonaDB geometries with CRS preserved.
    - Pandas DataFrame objects with a single column and single row
      are converted using `pa.array()`.
    - SedonaDB DataFrame objects that evaluate to a single column and
      row become a scalar value according to the single represented
      value.

    """
    # Idempotent: never wrap a Literal inside another Literal.
    return value if isinstance(value, Literal) else Literal(value)


def _resolve_arrow_lit(obj: Any):
    """Convert an arbitrary Python object into an Arrow array-like object

    Resolution order:

    1. A converter registered in `SPECIAL_CASED_LITERALS` under the object's
       qualified type name.
    2. The object itself, if it exposes `__arrow_c_array__`.
    3. `pyarrow.array([obj])`.

    Raises:
        ValueError: If the `pyarrow.array([obj])` fallback fails; the
            underlying exception is chained as the cause.
    """
    type_name = _qualified_type_name(obj)

    converter = SPECIAL_CASED_LITERALS.get(type_name)
    if converter is not None:
        return converter(obj)

    if hasattr(obj, "__arrow_c_array__"):
        return obj

    # Imported only when the generic fallback is actually needed.
    import pyarrow as pa

    try:
        return pa.array([obj])
    except Exception as e:
        message = f"Can't create SedonaDB literal from object of type {type_name}"
        raise ValueError(message) from e


def _lit_from_geoarrow_scalar(obj):
    """Build a literal from a geoarrow scalar, keeping the CRS of its type

    A scalar whose `value` is None is passed on as a None WKB value.
    """
    if obj.value is None:
        wkb_value = None
    else:
        wkb_value = obj.wkb

    return _lit_from_wkb_and_crs(wkb_value, obj.type.crs)


def _lit_from_dataframe(obj):
if obj.shape != (1, 1):
raise ValueError(
"Can't create SedonaDB literal from DataFrame with shape != (1, 1)"
)

return _resolve_arrow_lit(obj.iloc[0])


def _lit_from_series(obj):
if len(obj) != 1:
raise ValueError("Can't create SedonaDB literal from Series with length != 1")

# A column with dtype "geometry" is not always a GeoSeries; however, if the dtype
# is geometry, obj.array.crs should still be available to extract the CRS.
if obj.dtype.name == "geometry":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also check crs is defined here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if dtype.name == "geometry" implies that the series is a GeoPandas GeoSeries, and if can assume that crs is defined in this case. Besides, the PR LGTM.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is because of a fun corner case in Pandas land: geo_df.iloc[0] is a Series with dtype geometry, not a GeoSeries (hence checking obj.array.crs instead of obj.crs).

first_value = obj.array[0]
first_wkb = None if first_value is None else first_value.wkb
return _lit_from_wkb_and_crs(first_wkb, obj.array.crs)
else:
import pyarrow as pa

return pa.array(obj)


def _lit_from_sedonadb(obj):
if len(obj.columns) != 1:
raise ValueError(
"Can't create SedonaDB literal from SedonaDB DataFrame with number of columns != 1"
)

tab = obj.limit(2).to_arrow_table()
if len(tab) != 1:
raise ValueError(
"Can't create SedonaDB literal from SedonaDB DataFrame with size != 1 row"
)

return tab[0].chunk(0)


def _lit_from_shapely(obj):
    """Build a literal from an object exposing `wkb`; no CRS is attached"""
    wkb_value = obj.wkb
    return _lit_from_wkb_and_crs(wkb_value, None)


def _lit_from_wkb_and_crs(wkb, crs):
    """Wrap a single WKB value (or None) in a length-1 geoarrow WKB array

    Args:
        wkb: The WKB value to store as the only element, or None.
        crs: The CRS passed to `with_crs()` on the geoarrow WKB type.
    """
    import geoarrow.pyarrow as ga
    import pyarrow as pa

    wkb_type = ga.wkb().with_crs(crs)
    # Build the plain storage array first, then wrap it in the extension type.
    storage_array = pa.array([wkb], wkb_type.storage_type)
    return wkb_type.wrap_array(storage_array)


def _qualified_type_name(obj):
return f"{type(obj).__module__}.{type(obj).__name__}"


# Converters for objects that need special handling, keyed by qualified type
# name as produced by _qualified_type_name().
#
# Matching on the class name (which does not catch subclasses) is deliberate:
# an isinstance() check against e.g. shapely.Geometry would require importing
# shapely at module level, which this module avoids.
SPECIAL_CASED_LITERALS = {
    "geopandas.geodataframe.GeoDataFrame": _lit_from_dataframe,
    "geopandas.geoseries.GeoSeries": _lit_from_series,
    # pandas < 3.0
    "pandas.core.frame.DataFrame": _lit_from_dataframe,
    "pandas.core.series.Series": _lit_from_series,
    # pandas >= 3.0
    "pandas.DataFrame": _lit_from_dataframe,
    "pandas.Series": _lit_from_series,
    "sedonadb.dataframe.DataFrame": _lit_from_sedonadb,
    "shapely.geometry.point.Point": _lit_from_shapely,
    "shapely.geometry.linestring.LineString": _lit_from_shapely,
    "shapely.geometry.polygon.Polygon": _lit_from_shapely,
    "shapely.geometry.polygon.LinearRing": _lit_from_shapely,
    "shapely.geometry.multipoint.MultiPoint": _lit_from_shapely,
    "shapely.geometry.multilinestring.MultiLineString": _lit_from_shapely,
    "shapely.geometry.multipolygon.MultiPolygon": _lit_from_shapely,
    "shapely.geometry.collection.GeometryCollection": _lit_from_shapely,
    "geoarrow.pyarrow._scalar.WkbScalar": _lit_from_geoarrow_scalar,
}
Loading