Skip to content

add shapes parameter to render shapes as hex/circle/square #474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import OrderedDict
from copy import deepcopy
from pathlib import Path
from typing import Any
from typing import Any, Literal

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -170,6 +170,7 @@ def render_shapes(
method: str | None = None,
table_name: str | None = None,
table_layer: str | None = None,
shape: Literal["circle", "hex", "square"] | None = None,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -232,6 +233,9 @@ def render_shapes(
table_layer: str | None
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
:attr:`sdata.table.X` is used for coloring.
shape: Literal["circle", "hex", "square"] | None
If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is
specified, the shapes are converted to a circle/hexagon/square before rendering.

**kwargs : Any
Additional arguments for customization. This can include:
Expand Down Expand Up @@ -276,6 +280,7 @@ def render_shapes(
scale=scale,
table_name=table_name,
table_layer=table_layer,
shape=shape,
method=method,
ds_reduction=kwargs.get("datashader_reduction"),
)
Expand Down Expand Up @@ -304,6 +309,7 @@ def render_shapes(
transfunc=kwargs.get("transfunc"),
table_name=param_values["table_name"],
table_layer=param_values["table_layer"],
shape=param_values["shape"],
zorder=n_steps,
method=param_values["method"],
ds_reduction=param_values["ds_reduction"],
Expand Down
18 changes: 13 additions & 5 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from spatialdata_plot.pl.utils import (
_ax_show_and_transform,
_convert_shapes,
_create_image_from_datashader_result,
_datashader_aggregate_with_function,
_datashader_map_aggregate_to_color,
Expand Down Expand Up @@ -160,6 +161,15 @@ def _render_shapes(
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)

shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
# convert shapes if necessary
if render_params.shape is not None:
current_type = shapes["geometry"].type
if not (render_params.shape == "circle" and (current_type == "Point").all()):
logger.info(f"Converting {shapes.shape[0]} shapes to {render_params.shape}.")
max_extent = np.max(
[shapes.total_bounds[2] - shapes.total_bounds[0], shapes.total_bounds[3] - shapes.total_bounds[1]]
)
shapes = _convert_shapes(shapes, render_params.shape, max_extent)

# Determine which method to use for rendering
method = render_params.method
Expand All @@ -183,17 +193,15 @@ def _render_shapes(
# Handle circles encoded as points with radius
if is_point.any():
scale = shapes[is_point]["radius"] * render_params.scale
sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())

# apply transformations to the individual points
element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system)
tm = _get_transformation_matrix_for_datashader(element_trans)
transformed_element = sdata_filt.shapes[element].transform(
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]
)
transformed_element = shapes.transform(lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2])
transformed_element = ShapesModel.parse(
gpd.GeoDataFrame(
data=sdata_filt.shapes[element].drop("geometry", axis=1),
data=shapes.drop("geometry", axis=1),
geometry=transformed_element,
)
)
Expand Down
1 change: 1 addition & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class ShapesRenderParams:
zorder: int = 0
table_name: str | None = None
table_layer: str | None = None
shape: Literal["circle", "hex", "square"] | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None


Expand Down
145 changes: 138 additions & 7 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import math
import os
import warnings
from collections import OrderedDict
Expand Down Expand Up @@ -51,6 +52,7 @@
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation
from scanpy.plotting.palettes import default_20, default_28, default_102
from scipy.spatial import ConvexHull
from skimage.color import label2rgb
from skimage.morphology import erosion, square
from skimage.segmentation import find_boundaries
Expand Down Expand Up @@ -1709,6 +1711,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
if size < 0:
raise ValueError("Parameter 'size' must be a positive number.")

if element_type == "shapes" and (shape := param_dict.get("shape")) is not None:
if not isinstance(shape, str):
raise TypeError("Parameter 'shape' must be a String from ['circle', 'hex', 'square'] if not None.")
if shape not in ["circle", "hex", "square"]:
raise ValueError(
f"'{shape}' is not supported for 'shape', please choose from[None, 'circle', 'hex', 'square']."
)

table_name = param_dict.get("table_name")
table_layer = param_dict.get("table_layer")
if table_name and not isinstance(param_dict["table_name"], str):
Expand Down Expand Up @@ -1920,6 +1930,7 @@ def _validate_shape_render_params(
scale: float | int,
table_name: str | None,
table_layer: str | None,
shape: Literal["circle", "hex", "square"] | None,
method: str | None,
ds_reduction: str | None,
) -> dict[str, dict[str, Any]]:
Expand All @@ -1939,6 +1950,7 @@ def _validate_shape_render_params(
"scale": scale,
"table_name": table_name,
"table_layer": table_layer,
"shape": shape,
"method": method,
"ds_reduction": ds_reduction,
}
Expand All @@ -1959,6 +1971,7 @@ def _validate_shape_render_params(
element_params[el]["norm"] = param_dict["norm"]
element_params[el]["scale"] = param_dict["scale"]
element_params[el]["table_layer"] = param_dict["table_layer"]
element_params[el]["shape"] = param_dict["shape"]

element_params[el]["color"] = param_dict["color"]

Expand Down Expand Up @@ -2086,7 +2099,7 @@ def _validate_image_render_params(
def _get_wanted_render_elements(
sdata: SpatialData,
sdata_wanted_elements: list[str],
params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams),
params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
cs: str,
element_type: Literal["images", "labels", "points", "shapes"],
) -> tuple[list[str], list[str], bool]:
Expand Down Expand Up @@ -2243,7 +2256,7 @@ def _create_image_from_datashader_result(


def _datashader_aggregate_with_function(
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
cvs: Canvas,
spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
col_for_color: str | None,
Expand Down Expand Up @@ -2307,7 +2320,7 @@ def _datashader_aggregate_with_function(


def _datshader_get_how_kw_for_spread(
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
) -> str:
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
reduction = reduction or "sum"
Expand Down Expand Up @@ -2350,15 +2363,15 @@ def _prepare_transformation(

def _get_datashader_trans_matrix_of_single_element(
trans: Identity | Scale | Affine | MapAxis | Translation,
) -> npt.NDArray[Any]:
) -> ArrayLike:
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y"))
tm: ArrayLike = trans.to_affine_matrix(("x", "y"), ("x", "y"))

if isinstance(trans, Identity):
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
if isinstance(trans, (Scale | Affine)):
# idea: "flip the y-axis", apply transformation, flip back
flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix
flip_and_transform: ArrayLike = flip_matrix @ tm @ flip_matrix
return flip_and_transform
if isinstance(trans, MapAxis):
# no flipping needed
Expand All @@ -2369,7 +2382,7 @@ def _get_datashader_trans_matrix_of_single_element(

def _get_transformation_matrix_for_datashader(
trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence,
) -> npt.NDArray[Any]:
) -> ArrayLike:
"""Get the affine matrix needed to transform shapes for rendering with datashader."""
if isinstance(trans, SDSequence):
tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
Expand Down Expand Up @@ -2478,3 +2491,121 @@ def _hex_no_alpha(hex: str) -> str:
return "#" + hex_digits[:6]

raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'")


def _convert_shapes(
shapes: GeoDataFrame, target_shape: str, max_extent: float, warn_above_extent_fraction: float = 0.5
) -> GeoDataFrame:
"""Convert the shapes stored in a GeoDataFrame (geometry column) to the target_shape."""
# NOTE: possible follow-up: when converting equally sized shapes to hex, automatically scale resulting hexagons
# so that they are perfectly adjacent to each other

if warn_above_extent_fraction < 0.0 or warn_above_extent_fraction > 1.0:
warn_above_extent_fraction = 0.5 # set to default if the value is outside [0, 1]
warn_shape_size = False

# define individual conversion methods
def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
vertices = [
(center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle)))
for angle in range(0, 360, 60)
]
return shapely.Polygon(vertices), None

def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
vertices = [
(center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle)))
for angle in range(45, 360, 90)
]
return shapely.Polygon(vertices), None

def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]:
return center, radius

def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
center, radius = _polygon_to_circle(polygon)
return _circle_to_hexagon(center, radius)

def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
center, radius = _polygon_to_circle(polygon)
return _circle_to_square(center, radius)

def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]:
coords = np.array(polygon.exterior.coords)
circle_points = coords[ConvexHull(coords).vertices]
center = np.mean(circle_points, axis=0)
radius = max(float(np.linalg.norm(p - center)) for p in circle_points)
assert isinstance(radius, float) # shut up mypy
if 2 * radius > max_extent * warn_above_extent_fraction:
nonlocal warn_shape_size
warn_shape_size = True
return shapely.Point(center), radius

def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
center, radius = _multipolygon_to_circle(multipolygon)
return _circle_to_hexagon(center, radius)

def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
center, radius = _multipolygon_to_circle(multipolygon)
return _circle_to_square(center, radius)

def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]:
coords = []
for polygon in multipolygon.geoms:
coords.extend(polygon.exterior.coords)
points = np.array(coords)
circle_points = points[ConvexHull(points).vertices]
center = np.mean(circle_points, axis=0)
radius = max(float(np.linalg.norm(p - center)) for p in circle_points)
assert isinstance(radius, float) # shut up mypy
if 2 * radius > max_extent * warn_above_extent_fraction:
nonlocal warn_shape_size
warn_shape_size = True
return shapely.Point(center), radius

# define dict with all conversion methods
if target_shape == "circle":
conversion_methods = {
"Point": _circle_to_circle,
"Polygon": _polygon_to_circle,
"Multipolygon": _multipolygon_to_circle,
}
pass
elif target_shape == "hex":
conversion_methods = {
"Point": _circle_to_hexagon,
"Polygon": _polygon_to_hexagon,
"Multipolygon": _multipolygon_to_hexagon,
}
else:
conversion_methods = {
"Point": _circle_to_square,
"Polygon": _polygon_to_square,
"Multipolygon": _multipolygon_to_square,
}

# convert every shape
for i in range(shapes.shape[0]):
if shapes["geometry"][i].type == "Point":
converted, radius = conversion_methods["Point"](shapes["geometry"][i], shapes["radius"][i]) # type: ignore
elif shapes["geometry"][i].type == "Polygon":
converted, radius = conversion_methods["Polygon"](shapes["geometry"][i]) # type: ignore
elif shapes["geometry"][i].type == "MultiPolygon":
converted, radius = conversion_methods["Multipolygon"](shapes["geometry"][i]) # type: ignore
else:
error_type = shapes["geometry"][i].type
raise ValueError(f"Converting shape {error_type} to {target_shape} is not supported.")
shapes["geometry"][i] = converted
if radius is not None:
if "radius" not in shapes.columns:
shapes["radius"] = np.nan
shapes["radius"][i] = radius

if warn_shape_size:
logger.info(
f"When converting the shapes, the size of at least one target shape extends "
f"{warn_above_extent_fraction * 100}% of the original total bound of the shapes. The conversion"
" might not give satisfying results in this scenario."
)

return shapes
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Shapes_datashader_can_transform_circles.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
48 changes: 48 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,51 @@ def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialDat
sdata_blobs["circle_table"].layers["normalized"] = RNG.random((nrows, ncols))

sdata_blobs.pl.render_shapes("blobs_circles", color="feature0", table_layer="normalized").pl.show()

def test_plot_can_render_circles_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex").pl.show()

def test_plot_can_render_circles_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", shape="square").pl.show()

def test_plot_can_render_polygons_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="hex").pl.show()

def test_plot_can_render_polygons_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="square").pl.show()

def test_plot_can_render_polygons_to_circle(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="circle").pl.show()

def test_plot_can_render_multipolygons_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="hex").pl.show()

def test_plot_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="square").pl.show()

def test_plot_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle").pl.show()

def test_plot_datashader_can_render_circles_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex", method="datashader").pl.show()

def test_plot_datashader_can_render_circles_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", shape="square", method="datashader").pl.show()

def test_plot_datashader_can_render_polygons_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="hex", method="datashader").pl.show()

def test_plot_datashader_can_render_polygons_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="square", method="datashader").pl.show()

def test_plot_datashader_can_render_polygons_to_circle(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="circle", method="datashader").pl.show()

def test_plot_datashader_can_render_multipolygons_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="hex", method="datashader").pl.show()

def test_plot_datashader_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="square", method="datashader").pl.show()

def test_plot_datashader_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle", method="datashader").pl.show()
Loading