diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 2dd9019f..6171ac5f 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -5,7 +5,7 @@ from collections import OrderedDict from copy import deepcopy from pathlib import Path -from typing import Any +from typing import Any, Literal import matplotlib.pyplot as plt import numpy as np @@ -170,6 +170,7 @@ def render_shapes( method: str | None = None, table_name: str | None = None, table_layer: str | None = None, + shape: Literal["circle", "hex", "square"] | None = None, **kwargs: Any, ) -> sd.SpatialData: """ @@ -232,6 +233,9 @@ def render_shapes( table_layer: str | None Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in :attr:`sdata.table.X` is used for coloring. + shape: Literal["circle", "hex", "square"] | None + If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is + specified, the shapes are converted to a circle/hexagon/square before rendering. **kwargs : Any Additional arguments for customization. This can include: @@ -276,6 +280,7 @@ def render_shapes( scale=scale, table_name=table_name, table_layer=table_layer, + shape=shape, method=method, ds_reduction=kwargs.get("datashader_reduction"), ) @@ -304,6 +309,7 @@ def render_shapes( transfunc=kwargs.get("transfunc"), table_name=param_values["table_name"], table_layer=param_values["table_layer"], + shape=param_values["shape"], zorder=n_steps, method=param_values["method"], ds_reduction=param_values["ds_reduction"], diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 52297005..f280da3f 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -35,6 +35,7 @@ ) from spatialdata_plot.pl.utils import ( _ax_show_and_transform, + _convert_shapes, _create_image_from_datashader_result, _datashader_aggregate_with_function, _datashader_map_aggregate_to_color, @@ -160,6 +161,15 @@ def _render_shapes( trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system) shapes = gpd.GeoDataFrame(shapes, geometry="geometry") + # convert shapes if necessary + if render_params.shape is not None: + current_type = shapes["geometry"].type + if not (render_params.shape == "circle" and (current_type == "Point").all()): + logger.info(f"Converting {shapes.shape[0]} shapes to {render_params.shape}.") + max_extent = np.max( + [shapes.total_bounds[2] - shapes.total_bounds[0], shapes.total_bounds[3] - shapes.total_bounds[1]] + ) + shapes = _convert_shapes(shapes, render_params.shape, max_extent) # Determine which method to use for rendering method = render_params.method @@ -183,17 +193,15 @@ def _render_shapes( # Handle circles encoded as points with radius if is_point.any(): scale = shapes[is_point]["radius"] * render_params.scale - sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy()) + shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy()) # apply transformations to the individual points element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system) tm = _get_transformation_matrix_for_datashader(element_trans) - transformed_element = sdata_filt.shapes[element].transform( - lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2] - ) + transformed_element = shapes.transform(lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]) transformed_element = ShapesModel.parse( gpd.GeoDataFrame( - data=sdata_filt.shapes[element].drop("geometry", axis=1), + data=shapes.drop("geometry", axis=1), geometry=transformed_element, ) ) diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index b44175c3..d446a5a5 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -90,6 +90,7 @@ class ShapesRenderParams: zorder: int = 0 table_name: str | None = None table_layer: str | None = None + shape: Literal["circle", "hex", "square"] | None = None ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index a2e8f767..90d202c4 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import os import warnings from collections import OrderedDict @@ -51,6 +52,7 @@ from scanpy.plotting._tools.scatterplots import _add_categorical_legend from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation from scanpy.plotting.palettes import default_20, default_28, default_102 +from scipy.spatial import ConvexHull from skimage.color import label2rgb from skimage.morphology import erosion, square from skimage.segmentation import find_boundaries @@ -1709,6 +1711,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if size < 0: raise ValueError("Parameter 'size' must be a positive number.") + if element_type == "shapes" and (shape := param_dict.get("shape")) is not None: + if not isinstance(shape, str): + raise TypeError("Parameter 'shape' must be a String from ['circle', 'hex', 'square'] if not None.") + if shape not in ["circle", "hex", "square"]: + raise ValueError( + f"'{shape}' is not supported for 'shape', please choose from[None, 'circle', 'hex', 'square']." + ) + table_name = param_dict.get("table_name") table_layer = param_dict.get("table_layer") if table_name and not isinstance(param_dict["table_name"], str): @@ -1920,6 +1930,7 @@ def _validate_shape_render_params( scale: float | int, table_name: str | None, table_layer: str | None, + shape: Literal["circle", "hex", "square"] | None, method: str | None, ds_reduction: str | None, ) -> dict[str, dict[str, Any]]: @@ -1939,6 +1950,7 @@ def _validate_shape_render_params( "scale": scale, "table_name": table_name, "table_layer": table_layer, + "shape": shape, "method": method, "ds_reduction": ds_reduction, } @@ -1959,6 +1971,7 @@ def _validate_shape_render_params( element_params[el]["norm"] = param_dict["norm"] element_params[el]["scale"] = param_dict["scale"] element_params[el]["table_layer"] = param_dict["table_layer"] + element_params[el]["shape"] = param_dict["shape"] element_params[el]["color"] = param_dict["color"] @@ -2086,7 +2099,7 @@ def _validate_image_render_params( def _get_wanted_render_elements( sdata: SpatialData, sdata_wanted_elements: list[str], - params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams), + params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, cs: str, element_type: Literal["images", "labels", "points", "shapes"], ) -> tuple[list[str], list[str], bool]: @@ -2243,7 +2256,7 @@ def _create_image_from_datashader_result( def _datashader_aggregate_with_function( - reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None), + reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, cvs: Canvas, spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame, col_for_color: str | None, @@ -2307,7 +2320,7 @@ def _datashader_aggregate_with_function( def _datshader_get_how_kw_for_spread( - reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None), + reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, ) -> str: # Get the best input for the how argument of ds.tf.spread(), needed for numerical values reduction = reduction or "sum" @@ -2350,15 +2363,15 @@ def _prepare_transformation( def _get_datashader_trans_matrix_of_single_element( trans: Identity | Scale | Affine | MapAxis | Translation, -) -> npt.NDArray[Any]: +) -> ArrayLike: flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) - tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y")) + tm: ArrayLike = trans.to_affine_matrix(("x", "y"), ("x", "y")) if isinstance(trans, Identity): return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) if isinstance(trans, (Scale | Affine)): # idea: "flip the y-axis", apply transformation, flip back - flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix + flip_and_transform: ArrayLike = flip_matrix @ tm @ flip_matrix return flip_and_transform if isinstance(trans, MapAxis): # no flipping needed @@ -2369,7 +2382,7 @@ def _get_datashader_trans_matrix_of_single_element( def _get_transformation_matrix_for_datashader( trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence, -) -> npt.NDArray[Any]: +) -> ArrayLike: """Get the affine matrix needed to transform shapes for rendering with datashader.""" if isinstance(trans, SDSequence): tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) @@ -2478,3 +2491,121 @@ def _hex_no_alpha(hex: str) -> str: return "#" + hex_digits[:6] raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'") + + +def _convert_shapes( + shapes: GeoDataFrame, target_shape: str, max_extent: float, warn_above_extent_fraction: float = 0.5 +) -> GeoDataFrame: + """Convert the shapes stored in a GeoDataFrame (geometry column) to the target_shape.""" + # NOTE: possible follow-up: when converting equally sized shapes to hex, automatically scale resulting hexagons + # so that they are perfectly adjacent to each other + + if warn_above_extent_fraction < 0.0 or warn_above_extent_fraction > 1.0: + warn_above_extent_fraction = 0.5 # set to default if the value is outside [0, 1] + warn_shape_size = False + + # define individual conversion methods + def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: + vertices = [ + (center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle))) + for angle in range(0, 360, 60) + ] + return shapely.Polygon(vertices), None + + def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: + vertices = [ + (center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle))) + for angle in range(45, 360, 90) + ] + return shapely.Polygon(vertices), None + + def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]: + return center, radius + + def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]: + center, radius = _polygon_to_circle(polygon) + return _circle_to_hexagon(center, radius) + + def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]: + center, radius = _polygon_to_circle(polygon) + return _circle_to_square(center, radius) + + def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]: + coords = np.array(polygon.exterior.coords) + circle_points = coords[ConvexHull(coords).vertices] + center = np.mean(circle_points, axis=0) + radius = max(float(np.linalg.norm(p - center)) for p in circle_points) + assert isinstance(radius, float) # shut up mypy + if 2 * radius > max_extent * warn_above_extent_fraction: + nonlocal warn_shape_size + warn_shape_size = True + return shapely.Point(center), radius + + def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]: + center, radius = _multipolygon_to_circle(multipolygon) + return _circle_to_hexagon(center, radius) + + def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]: + center, radius = _multipolygon_to_circle(multipolygon) + return _circle_to_square(center, radius) + + def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]: + coords = [] + for polygon in multipolygon.geoms: + coords.extend(polygon.exterior.coords) + points = np.array(coords) + circle_points = points[ConvexHull(points).vertices] + center = np.mean(circle_points, axis=0) + radius = max(float(np.linalg.norm(p - center)) for p in circle_points) + assert isinstance(radius, float) # shut up mypy + if 2 * radius > max_extent * warn_above_extent_fraction: + nonlocal warn_shape_size + warn_shape_size = True + return shapely.Point(center), radius + + # define dict with all conversion methods + if target_shape == "circle": + conversion_methods = { + "Point": _circle_to_circle, + "Polygon": _polygon_to_circle, + "Multipolygon": _multipolygon_to_circle, + } + pass + elif target_shape == "hex": + conversion_methods = { + "Point": _circle_to_hexagon, + "Polygon": _polygon_to_hexagon, + "Multipolygon": _multipolygon_to_hexagon, + } + else: + conversion_methods = { + "Point": _circle_to_square, + "Polygon": _polygon_to_square, + "Multipolygon": _multipolygon_to_square, + } + + # convert every shape + for i in range(shapes.shape[0]): + if shapes["geometry"][i].type == "Point": + converted, radius = conversion_methods["Point"](shapes["geometry"][i], shapes["radius"][i]) # type: ignore + elif shapes["geometry"][i].type == "Polygon": + converted, radius = conversion_methods["Polygon"](shapes["geometry"][i]) # type: ignore + elif shapes["geometry"][i].type == "MultiPolygon": + converted, radius = conversion_methods["Multipolygon"](shapes["geometry"][i]) # type: ignore + else: + error_type = shapes["geometry"][i].type + raise ValueError(f"Converting shape {error_type} to {target_shape} is not supported.") + shapes["geometry"][i] = converted + if radius is not None: + if "radius" not in shapes.columns: + shapes["radius"] = np.nan + shapes["radius"][i] = radius + + if warn_shape_size: + logger.info( + f"When converting the shapes, the size of at least one target shape extends " + f"{warn_above_extent_fraction * 100}% of the original total bound of the shapes. The conversion" + " might not give satisfying results in this scenario." + ) + + return shapes diff --git a/tests/_images/Shapes_can_render_circles_to_hex.png b/tests/_images/Shapes_can_render_circles_to_hex.png new file mode 100644 index 00000000..026fdd1e Binary files /dev/null and b/tests/_images/Shapes_can_render_circles_to_hex.png differ diff --git a/tests/_images/Shapes_can_render_circles_to_square.png b/tests/_images/Shapes_can_render_circles_to_square.png new file mode 100644 index 00000000..13003c8a Binary files /dev/null and b/tests/_images/Shapes_can_render_circles_to_square.png differ diff --git a/tests/_images/Shapes_can_render_multipolygons_to_circle.png b/tests/_images/Shapes_can_render_multipolygons_to_circle.png new file mode 100644 index 00000000..fe5a17e7 Binary files /dev/null and b/tests/_images/Shapes_can_render_multipolygons_to_circle.png differ diff --git a/tests/_images/Shapes_can_render_multipolygons_to_hex.png b/tests/_images/Shapes_can_render_multipolygons_to_hex.png new file mode 100644 index 00000000..e5ac72dc Binary files /dev/null and b/tests/_images/Shapes_can_render_multipolygons_to_hex.png differ diff --git a/tests/_images/Shapes_can_render_multipolygons_to_square.png b/tests/_images/Shapes_can_render_multipolygons_to_square.png new file mode 100644 index 00000000..0646e548 Binary files /dev/null and b/tests/_images/Shapes_can_render_multipolygons_to_square.png differ diff --git a/tests/_images/Shapes_can_render_polygons_to_circle.png b/tests/_images/Shapes_can_render_polygons_to_circle.png new file mode 100644 index 00000000..fc3e3906 Binary files /dev/null and b/tests/_images/Shapes_can_render_polygons_to_circle.png differ diff --git a/tests/_images/Shapes_can_render_polygons_to_hex.png b/tests/_images/Shapes_can_render_polygons_to_hex.png new file mode 100644 index 00000000..45d3be26 Binary files /dev/null and b/tests/_images/Shapes_can_render_polygons_to_hex.png differ diff --git a/tests/_images/Shapes_can_render_polygons_to_square.png b/tests/_images/Shapes_can_render_polygons_to_square.png new file mode 100644 index 00000000..02bf5e03 Binary files /dev/null and b/tests/_images/Shapes_can_render_polygons_to_square.png differ diff --git a/tests/_images/Shapes_datashader_can_render_circles_to_hex.png b/tests/_images/Shapes_datashader_can_render_circles_to_hex.png new file mode 100644 index 00000000..d6aebef8 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_circles_to_hex.png differ diff --git a/tests/_images/Shapes_datashader_can_render_circles_to_square.png b/tests/_images/Shapes_datashader_can_render_circles_to_square.png new file mode 100644 index 00000000..c5776d4a Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_circles_to_square.png differ diff --git a/tests/_images/Shapes_datashader_can_render_multipolygons_to_circle.png b/tests/_images/Shapes_datashader_can_render_multipolygons_to_circle.png new file mode 100644 index 00000000..ee543dc1 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_multipolygons_to_circle.png differ diff --git a/tests/_images/Shapes_datashader_can_render_multipolygons_to_hex.png b/tests/_images/Shapes_datashader_can_render_multipolygons_to_hex.png new file mode 100644 index 00000000..7028fc4c Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_multipolygons_to_hex.png differ diff --git a/tests/_images/Shapes_datashader_can_render_multipolygons_to_square.png b/tests/_images/Shapes_datashader_can_render_multipolygons_to_square.png new file mode 100644 index 00000000..90701900 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_multipolygons_to_square.png differ diff --git a/tests/_images/Shapes_datashader_can_render_polygons_to_circle.png b/tests/_images/Shapes_datashader_can_render_polygons_to_circle.png new file mode 100644 index 00000000..01f93369 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_polygons_to_circle.png differ diff --git a/tests/_images/Shapes_datashader_can_render_polygons_to_hex.png b/tests/_images/Shapes_datashader_can_render_polygons_to_hex.png new file mode 100644 index 00000000..8f460b6c Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_polygons_to_hex.png differ diff --git a/tests/_images/Shapes_datashader_can_render_polygons_to_square.png b/tests/_images/Shapes_datashader_can_render_polygons_to_square.png new file mode 100644 index 00000000..2ae09482 Binary files /dev/null and b/tests/_images/Shapes_datashader_can_render_polygons_to_square.png differ diff --git a/tests/_images/Shapes_datashader_can_transform_circles.png b/tests/_images/Shapes_datashader_can_transform_circles.png index 60cde073..a67f9daa 100644 Binary files a/tests/_images/Shapes_datashader_can_transform_circles.png and b/tests/_images/Shapes_datashader_can_transform_circles.png differ diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 953eb843..9bcc01ab 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -562,3 +562,51 @@ def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialDat sdata_blobs["circle_table"].layers["normalized"] = RNG.random((nrows, ncols)) sdata_blobs.pl.render_shapes("blobs_circles", color="feature0", table_layer="normalized").pl.show() + + def test_plot_can_render_circles_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex").pl.show() + + def test_plot_can_render_circles_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", shape="square").pl.show() + + def test_plot_can_render_polygons_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="hex").pl.show() + + def test_plot_can_render_polygons_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="square").pl.show() + + def test_plot_can_render_polygons_to_circle(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="circle").pl.show() + + def test_plot_can_render_multipolygons_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="hex").pl.show() + + def test_plot_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="square").pl.show() + + def test_plot_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle").pl.show() + + def test_plot_datashader_can_render_circles_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex", method="datashader").pl.show() + + def test_plot_datashader_can_render_circles_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_circles", shape="square", method="datashader").pl.show() + + def test_plot_datashader_can_render_polygons_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="hex", method="datashader").pl.show() + + def test_plot_datashader_can_render_polygons_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="square", method="datashader").pl.show() + + def test_plot_datashader_can_render_polygons_to_circle(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_polygons", shape="circle", method="datashader").pl.show() + + def test_plot_datashader_can_render_multipolygons_to_hex(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="hex", method="datashader").pl.show() + + def test_plot_datashader_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="square", method="datashader").pl.show() + + def test_plot_datashader_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle", method="datashader").pl.show()