[SEDONA-714] Add geopandas to spark arrow conversion. (#1825)
* SEDONA-714 Add geopandas to spark arrow conversion.

* Update python/sedona/utils/geoarrow.py

Co-authored-by: Dewey Dunnington <[email protected]>

* SEDONA-714 Add docs.
---------

Co-authored-by: Dewey Dunnington <[email protected]>
Imbruced and paleolimbot authored Feb 26, 2025
1 parent 5007131 commit ebd6f67
Showing 5 changed files with 263 additions and 1 deletion.
4 changes: 4 additions & 0 deletions Makefile
@@ -65,3 +65,7 @@ clean:
	rm -rf __pycache__
	rm -rf .mypy_cache
	rm -rf .pytest_cache

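# Build the documentation image and serve the docs locally on port 8000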
run-docs:
	docker build -f docker/docs/Dockerfile -t mkdocs-sedona .
	docker run --rm -it -p 8000:8000 -v ${PWD}:/docs mkdocs-sedona
8 changes: 8 additions & 0 deletions docker/docs/Dockerfile
@@ -0,0 +1,8 @@
FROM squidfunk/mkdocs-material:9.6

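# Build tools needed to compile native Python dependencies on Alpine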
RUN apk update
RUN apk add gcc musl-dev linux-headers
RUN pip install mkdocs-macros-plugin \
    mkdocs-git-revision-date-localized-plugin \
    mkdocs-jupyter \
    mike
19 changes: 19 additions & 0 deletions docs/tutorial/geopandas-shapely.md
@@ -67,6 +67,25 @@ This query will show the following outputs:
```

To leverage Arrow optimization and speed up the conversion, you can use the `create_spatial_dataframe` function,
which takes a SparkSession and a GeoDataFrame as parameters and returns a Sedona DataFrame.

```python
def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame
```

- spark: SparkSession
- gdf: gpd.GeoDataFrame
- return: DataFrame

Example:

```python
from sedona.utils.geoarrow import create_spatial_dataframe

create_spatial_dataframe(spark, gdf)
```
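
For a fuller picture, here is a self-contained sketch mirroring the tests added in this commit; it assumes an active `SparkSession` named `spark`, and the column names are illustrative:

```python
import geopandas as gpd

from sedona.utils.geoarrow import create_spatial_dataframe

# A small GeoDataFrame with one attribute column and one geometry column
gdf = gpd.GeoDataFrame(
    {
        "name": ["Sedona", "Apache"],
        "geometry": gpd.points_from_xy([0, 1], [0, 1]),
    }
)

df = create_spatial_dataframe(spark, gdf)
df.printSchema()  # "geometry" is exposed as Sedona's GeometryType
```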

### From Sedona DataFrame to GeoPandas

Reading data with Spark and converting to GeoPandas
139 changes: 138 additions & 1 deletion python/sedona/utils/geoarrow.py
@@ -14,13 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import itertools
from typing import List, Callable

# We may be able to achieve streaming rather than complete materialization by using
# the ArrowStreamSerializer (instead of the ArrowCollectSerializer)


from sedona.sql.types import GeometryType
from sedona.sql.st_functions import ST_AsEWKB
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType, StructField, DataType, ArrayType, MapType

from sedona.sql.types import GeometryType
import geopandas as gpd
from pyspark.sql.pandas.types import (
    from_arrow_type,
)
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer


def dataframe_to_arrow(df, crs=None):
@@ -186,3 +197,129 @@ def unique_srid_from_ewkb(obj):
    import pyproj

    return pyproj.CRS(f"EPSG:{epsg_code}")


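# Appends a numeric suffix to column names that occur more than once so that field names stay unique.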
def _dedup_names(names: List[str]) -> List[str]:
    if len(set(names)) == len(names):
        return names
    else:

        def _gen_dedup(_name: str) -> Callable[[], str]:
            _i = itertools.count()
            return lambda: f"{_name}_{next(_i)}"

        def _gen_identity(_name: str) -> Callable[[], str]:
            return lambda: _name

        gen_new_name = {
            name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name)
            for name, group in itertools.groupby(sorted(names))
        }
        return [gen_new_name[name]() for name in names]


# Backport from Spark 4.0
# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/types.py#L1385
def _deduplicate_field_names(dt: DataType) -> DataType:
    if isinstance(dt, StructType):
        dedup_field_names = _dedup_names(dt.names)

        return StructType(
            [
                StructField(
                    dedup_field_names[i],
                    _deduplicate_field_names(field.dataType),
                    nullable=field.nullable,
                )
                for i, field in enumerate(dt.fields)
            ]
        )
    elif isinstance(dt, ArrayType):
        return ArrayType(
            _deduplicate_field_names(dt.elementType), containsNull=dt.containsNull
        )
    elif isinstance(dt, MapType):
        return MapType(
            _deduplicate_field_names(dt.keyType),
            _deduplicate_field_names(dt.valueType),
            valueContainsNull=dt.valueContainsNull,
        )
    else:
        return dt


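# Builds the Spark schema for a GeoDataFrame: non-geometry columns are converted via pyarrow,
# and each geometry column is mapped to Sedona's GeometryType.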
def infer_schema(gdf: gpd.GeoDataFrame) -> StructType:
    import pyarrow as pa

    fields = gdf.dtypes.reset_index().values.tolist()
    geom_fields = []
    index = 0
    for name, dtype in fields:
        if dtype == "geometry":
            geom_fields.append((index, name))
            continue

        index += 1

    if not geom_fields:
        raise ValueError("No geometry field found in the GeoDataFrame")

    pa_schema = pa.Schema.from_pandas(
        gdf.drop([name for _, name in geom_fields], axis=1)
    )

    spark_schema = []

    for field in pa_schema:
        field_type = field.type
        spark_type = from_arrow_type(field_type)
        spark_schema.append(StructField(field.name, spark_type, True))

    for index, geom_field in geom_fields:
        spark_schema.insert(index, StructField(geom_field, GeometryType(), True))

    return StructType(spark_schema)


# Modified backport from Spark 4.0
# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/conversion.py#L632
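# Converts a GeoDataFrame into a Spark DataFrame by serializing it to Arrow record batches and handing them to the JVM.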
def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame:
    from pyspark.sql.pandas.types import (
        to_arrow_type,
    )

    def reader_func(temp_filename):
        return spark._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)

    def create_iter_server():
        return spark._jvm.ArrowIteratorServer()

    schema = infer_schema(gdf)
    timezone = spark._jconf.sessionLocalTimeZone()
    step = spark._jconf.arrowMaxRecordsPerBatch()
    step = step if step > 0 else len(gdf)
    pdf_slices = (gdf.iloc[start : start + step] for start in range(0, len(gdf), step))
    spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]

    arrow_data = [
        [
            (c, to_arrow_type(t) if t is not None else None, t)
            for (_, c), t in zip(pdf_slice.items(), spark_types)
        ]
        for pdf_slice in pdf_slices
    ]

    safecheck = spark._jconf.arrowSafeTypeConversion()
    ser = ArrowStreamPandasSerializer(timezone, safecheck)
    jiter = spark._sc._serialize_to_jvm(
        arrow_data, ser, reader_func, create_iter_server
    )

    jsparkSession = spark._jsparkSession
    jdf = spark._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)

    df = DataFrame(jdf, spark)

    df._schema = schema

    return df
94 changes: 94 additions & 0 deletions python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py
@@ -0,0 +1,94 @@
import pytest

from sedona.sql.types import GeometryType
from sedona.utils.geoarrow import create_spatial_dataframe
from tests.test_base import TestBase
import geopandas as gpd
import pyspark


class TestGeopandasToSedonaWithArrow(TestBase):

    @pytest.mark.skipif(
        not pyspark.__version__.startswith("3.5"),
        reason="It's only working with Spark 3.5",
    )
    def test_conversion_dataframe(self):
        gdf = gpd.GeoDataFrame(
            {
                "name": ["Sedona", "Apache"],
                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
            }
        )

        df = create_spatial_dataframe(self.spark, gdf)

        assert df.count() == 2
        assert df.columns == ["name", "geometry"]
        assert df.schema["geometry"].dataType == GeometryType()

    @pytest.mark.skipif(
        not pyspark.__version__.startswith("3.5"),
        reason="It's only working with Spark 3.5",
    )
    def test_different_geometry_positions(self):
        gdf = gpd.GeoDataFrame(
            {
                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
                "name": ["Sedona", "Apache"],
            }
        )

        gdf2 = gpd.GeoDataFrame(
            {
                "name": ["Sedona", "Apache"],
                "name1": ["Sedona", "Apache"],
                "name2": ["Sedona", "Apache"],
                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
            }
        )

        df1 = create_spatial_dataframe(self.spark, gdf)
        df2 = create_spatial_dataframe(self.spark, gdf2)

        assert df1.count() == 2
        assert df1.columns == ["geometry", "name"]
        assert df1.schema["geometry"].dataType == GeometryType()

        assert df2.count() == 2
        assert df2.columns == ["name", "name1", "name2", "geometry"]
        assert df2.schema["geometry"].dataType == GeometryType()

    @pytest.mark.skipif(
        not pyspark.__version__.startswith("3.5"),
        reason="It's only working with Spark 3.5",
    )
    def test_multiple_geometry_columns(self):
        gdf = gpd.GeoDataFrame(
            {
                "name": ["Sedona", "Apache"],
                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
                "geometry2": gpd.points_from_xy([0, 1], [0, 1]),
            }
        )

        df = create_spatial_dataframe(self.spark, gdf)

        assert df.count() == 2
        assert df.columns == ["name", "geometry2", "geometry"]
        assert df.schema["geometry"].dataType == GeometryType()
        assert df.schema["geometry2"].dataType == GeometryType()

    @pytest.mark.skipif(
        not pyspark.__version__.startswith("3.5"),
        reason="It's only working with Spark 3.5",
    )
    def test_missing_geometry_column(self):
        gdf = gpd.GeoDataFrame(
            {
                "name": ["Sedona", "Apache"],
            },
        )

        with pytest.raises(ValueError):
            create_spatial_dataframe(self.spark, gdf)
