diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index f126c5c0..cca98205 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -338,7 +338,7 @@ def _render_shapes( cax = None if aggregate_with_reduction is not None: vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin - vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax + vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: assert norm.vmin is not None assert norm.vmax is not None @@ -850,20 +850,22 @@ def _render_images( # 2) Image has any number of channels but 1 else: layers = {} - for ch_index, c in enumerate(channels): - layers[c] = img.sel(c=c).copy(deep=True).squeeze() - - if not isinstance(render_params.cmap_params, list): - if render_params.cmap_params.norm is not None: - layers[c] = render_params.cmap_params.norm(layers[c]) + for ch_idx, ch in enumerate(channels): + layers[ch] = img.sel(c=ch).copy(deep=True).squeeze() + if isinstance(render_params.cmap_params, list): + ch_norm = render_params.cmap_params[ch_idx].norm + ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default else: - if render_params.cmap_params[ch_index].norm is not None: - layers[c] = render_params.cmap_params[ch_index].norm(layers[c]) + ch_norm = render_params.cmap_params.norm + ch_cmap_is_default = render_params.cmap_params.cmap_is_default + + if not ch_cmap_is_default and ch_norm is not None: + layers[ch_idx] = ch_norm(layers[ch_idx]) # 2A) Image has 3 channels, no palette info, and no/only one cmap was given if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list): if render_params.cmap_params.cmap_is_default: # -> use RGB - stacked = np.stack([layers[c] for c in channels], axis=-1) + stacked = np.stack([layers[ch] for ch in layers], axis=-1) else: # -> use given cmap for each channel channel_cmaps = [render_params.cmap_params.cmap] * n_channels stacked = ( @@ -896,12 +898,54 @@ def _render_images( # overwrite if n_channels == 2 for intuitive result if n_channels == 2: seed_colors = ["#ff0000ff", "#00ff00ff"] - else: + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] + colored = np.stack( + [channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)], + 0, + ).sum(0) + colored = colored[:, :, :3] + elif n_channels == 3: seed_colors = _get_colors_for_categorical_obs(list(range(n_channels))) + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] + colored = np.stack( + [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], + 0, + ).sum(0) + colored = colored[:, :, :3] + else: + if isinstance(render_params.cmap_params, list): + cmap_is_default = render_params.cmap_params[0].cmap_is_default + else: + cmap_is_default = render_params.cmap_params.cmap_is_default - channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) - colored = colored[:, :, :3] + if cmap_is_default: + seed_colors = _get_colors_for_categorical_obs(list(range(n_channels))) + else: + # Sample n_channels colors evenly from the colormap + if isinstance(render_params.cmap_params, list): + seed_colors = [ + render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels) + ] + else: + seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)] + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] + + # Stack (n_channels, height, width) → (height*width, n_channels) + H, W = next(iter(layers.values())).shape + comp_rgb = np.zeros((H, W, 3), dtype=float) + + # For each channel: map to RGBA, apply constant alpha, then add + for ch_idx, ch in enumerate(channels): + layer_arr = layers[ch] + rgba = channel_cmaps[ch_idx](layer_arr) + rgba[..., 3] = render_params.alpha + comp_rgb += rgba[..., :3] * rgba[..., 3][..., None] + + colored = np.clip(comp_rgb, 0, 1) + logger.info( + f"Your image has {n_channels} channels. Sampling categorical colors and using " + f"multichannel strategy 'stack' to render." + ) # TODO: update when pca is added as strategy _ax_show_and_transform( colored, @@ -947,6 +991,7 @@ def _render_images( zorder=render_params.zorder, ) + # 2D) Image has n channels, no palette but cmap info elif palette is not None and got_multiple_cmaps: raise ValueError("If 'palette' is provided, 'cmap' must be None.") diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 4e4f0b5f..c795bbea 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2008,7 +2008,7 @@ def _validate_col_for_column_table( table_name = next(iter(tables)) if len(tables) > 1: warnings.warn( - f"Multiple tables contain color column, using {table_name}", + f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.", UserWarning, stacklevel=2, ) @@ -2044,44 +2044,57 @@ def _validate_image_render_params( element_params[el] = {} spatial_element = param_dict["sdata"][el] + # robustly get channel names from image or multiscale image spatial_element_ch = ( - spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c + spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values ) - channel = param_dict["channel"] - channel_list: list[str] | list[int] | None - if isinstance(channel, list): - type_ = type(channel[0]) - assert all(isinstance(ch, type_) for ch in channel), "All channels must be of the same type." - # mypy complains that channel_list can be also of type list[str | int] - channel_list = [channel] if isinstance(channel, int | str) else channel # type: ignore[assignment] - - if channel_list is not None and ( - (isinstance(channel_list[0], int) and max([abs(ch) for ch in channel_list]) <= len(spatial_element_ch)) # type: ignore[arg-type] - or all(ch in spatial_element_ch for ch in channel_list) - ): - element_params[el]["channel"] = channel_list + if channel is not None: + # Normalize channel to always be a list of str or a list of int + if isinstance(channel, str): + channel = [channel] + + if isinstance(channel, int): + channel = [channel] + + # If channel is a list, ensure all elements are the same type + if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)): + raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.") + + invalid = [c for c in channel if c not in spatial_element_ch] + if invalid: + raise ValueError( + f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}" + ) + element_params[el]["channel"] = channel else: element_params[el]["channel"] = None element_params[el]["alpha"] = param_dict["alpha"] - if isinstance(palette := param_dict["palette"], list): + palette = param_dict["palette"] + assert isinstance(palette, list | type(None)) # if present, was converted to list, just to make sure + + if isinstance(palette, list): + # case A: single palette for all channels if len(palette) == 1: - palette_length = len(channel_list) if channel_list is not None else len(spatial_element_ch) + palette_length = len(channel) if channel is not None else len(spatial_element_ch) palette = palette * palette_length - if (channel_list is not None and len(palette) != len(channel_list)) and len(palette) != len( - spatial_element_ch - ): - palette = None + # case B: one palette per channel (either given or derived from channel length) + channels_to_use = spatial_element_ch if element_params[el]["channel"] is None else channel + if channels_to_use is not None and len(palette) != len(channels_to_use): + raise ValueError( + f"Palette length ({len(palette)}) does not match channel length " + f"({', '.join(str(c) for c in channels_to_use)})." + ) element_params[el]["palette"] = palette element_params[el]["na_color"] = param_dict["na_color"] if (cmap := param_dict["cmap"]) is not None: if len(cmap) == 1: - cmap_length = len(channel_list) if channel_list is not None else len(spatial_element_ch) + cmap_length = len(channel) if channel is not None else len(spatial_element_ch) cmap = cmap * cmap_length - if (channel_list is not None and len(cmap) != len(channel_list)) or len(cmap) != len(spatial_element_ch): + if (channel is not None and len(cmap) != len(channel)) or len(cmap) != len(spatial_element_ch): cmap = None element_params[el]["cmap"] = cmap element_params[el]["norm"] = param_dict["norm"] @@ -2099,7 +2112,7 @@ def _validate_image_render_params( def _get_wanted_render_elements( sdata: SpatialData, sdata_wanted_elements: list[str], - params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams), + params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, cs: str, element_type: Literal["images", "labels", "points", "shapes"], ) -> tuple[list[str], list[str], bool]: @@ -2256,7 +2269,7 @@ def _create_image_from_datashader_result( def _datashader_aggregate_with_function( - reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None), + reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, cvs: Canvas, spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame, col_for_color: str | None, @@ -2320,7 +2333,7 @@ def _datashader_aggregate_with_function( def _datshader_get_how_kw_for_spread( - reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None), + reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, ) -> str: # Get the best input for the how argument of ds.tf.spread(), needed for numerical values reduction = reduction or "sum" diff --git a/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png b/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png index c22b9f2b..16bedd33 100644 Binary files a/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png and b/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png differ