Description
When trying to plot a categorical column (say `key`) of `sdata.tables[table_name].obs` using spatialdata-plot, the colors specified in `sdata.tables[table_name].uns["{key}_colors"]` are ignored because of this line:
[embedded code permalink]
PR #413 fixes this.
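For context, here is a minimal sketch of the color convention this issue relies on, assuming the usual AnnData/scanpy layout where `uns["{key}_colors"]` holds one color per category, in the same order as `obs[key].cat.categories`. `palette_from_uns` is a hypothetical helper for illustration, not part of the spatialdata-plot API.

```python
import pandas as pd


def palette_from_uns(obs_column: pd.Series, uns: dict, key: str) -> dict:
    """Map each category of ``obs_column`` to the color stored at the same
    position in ``uns[f"{key}_colors"]`` (hypothetical helper)."""
    categories = list(obs_column.cat.categories)
    colors = list(uns.get(f"{key}_colors", []))
    if len(colors) != len(categories):
        # Colors are matched to categories by position, so a length mismatch
        # cannot be resolved safely.
        raise ValueError(
            f"{len(categories)} categories but {len(colors)} colors for key {key!r}"
        )
    return dict(zip(categories, colors))


# String categories are sorted, so the categories are "a", "b", "c".
obs_column = pd.Series(["b", "a", "c", "a"], dtype="category")
uns = {"category_colors": ["#800080", "#008000", "#FFFF00"]}  # purple, green, yellow
palette = palette_from_uns(obs_column, uns, "category")
```

Here `palette` maps `"a"` to purple, `"b"` to green and `"c"` to yellow, which is the mapping the example below expects to see in the plot.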
Minimal example to reproduce the issue:
```python
import numpy as np
import pandas as pd
from anndata import AnnData
from spatialdata import get_element_instances
from spatialdata.datasets import blobs
from spatialdata.models import TableModel

import spatialdata_plot  # noqa: F401  (registers the .pl accessor)

RNG = np.random.default_rng(seed=42)
labels_name = "blobs_labels"
sdata_blobs = blobs()

instances = get_element_instances(sdata_blobs[labels_name])
n_obs = len(instances)
adata = AnnData(
    RNG.normal(size=(n_obs, 10)),
    obs=pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"]),
)
adata.obs["instance_id"] = instances.values

# Draw random categories and make sure each of "a", "b", "c" occurs at least once
# (set on the array first to avoid pandas chained assignment).
category = RNG.choice(["a", "b", "c"], size=adata.n_obs)
category[:3] = ["a", "b", "c"]
adata.obs["category"] = category
adata.obs["region"] = labels_name

table = TableModel.parse(
    adata=adata,
    region_key="region",
    instance_key="instance_id",
    region=labels_name,
)
sdata_blobs["other_table"] = table
sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category")
# One color per category, in category order: purple, green, yellow.
sdata_blobs["other_table"].uns["category_colors"] = ["#800080", "#008000", "#FFFF00"]
# Placeholder, otherwise "category_colors" will be ignored.
sdata_blobs["other_table"].uns["category"] = "__value__"

sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show()
```
With the small fix, this gives:
[image: plot with the fix applied]
I've added a unit test to the PR that reproduces the issue: https://github.com/ArneDefauw/spatialdata-plot/blob/5af65aa118f7abf87e47470038ecdbddb27ef1ca/tests/pl/test_render_labels.py#L217
As explained in the PR, I stumbled upon a second issue when plotting a subset of the data while trying to keep the same colors; see https://github.com/ArneDefauw/spatialdata-plot/blob/5af65aa118f7abf87e47470038ecdbddb27ef1ca/tests/pl/test_render_labels.py#L225.
Starting from the code above, if we do:
```python
from spatialdata import bounding_box_query

sdata_blobs = bounding_box_query(
    sdata_blobs,
    axes=("y", "x"),
    min_coordinate=[0, 0],
    max_coordinate=[100, 100],
    target_coordinate_system="global",
)
sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show()
```
we get:
[image: plot obtained after the query]
while we expect:
[image: expected plot]
This issue is caused by https://github.com/scverse/spatialdata/blob/03d3be80fad69ff54097e90a9e80ad02e9e0e242/src/spatialdata/_utils.py#L203.
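A pandas-only sketch of how the subset can end up with mismatched categories and colors. It rests on two assumptions not confirmed in this issue: that subsetting the AnnData table drops unused categories together with their colors in `uns`, and that the linked spatialdata code then restores the full category list of the original table on the subset without touching the colors.

```python
import pandas as pd

# Full table: three categories, one color per category (matched by position).
original = pd.Series(["a", "b", "c", "a"], dtype="category")
colors = ["#800080", "#008000", "#FFFF00"]  # purple, green, yellow

# Assumption 1: the subset keeps only used categories and their colors.
subset = original.iloc[[0, 2]].cat.remove_unused_categories()  # only "a" and "c"
subset_colors = [
    color
    for category, color in zip(original.cat.categories, colors)
    if category in subset.cat.categories
]

# Assumption 2: the full category list is restored, but the colors are not.
restored = subset.cat.set_categories(original.cat.categories)

print(list(restored.cat.categories), subset_colors)
# → ['a', 'b', 'c'] ['#800080', '#FFFF00']
```

Because colors are paired with categories by position, three categories and two colors no longer line up: zipping them would give `"b"` the color meant for `"c"`.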
There is a workaround for this:
```python
sdata_blobs["other_table"].obs["category"] = (
    sdata_blobs["other_table"].obs["category"].cat.remove_unused_categories()
)
```
as shown in the test at https://github.com/ArneDefauw/spatialdata-plot/blob/5af65aa118f7abf87e47470038ecdbddb27ef1ca/tests/pl/test_render_labels.py#L250, but this workaround should probably be documented somewhere public as well.
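The same workaround can be applied to every categorical column of a table at once. A minimal sketch; `remove_unused_categories_inplace` is a hypothetical helper, not part of spatialdata or spatialdata-plot:

```python
import pandas as pd


def remove_unused_categories_inplace(obs: pd.DataFrame) -> None:
    """Drop unused categories from every categorical column of ``obs`` (hypothetical helper)."""
    for column in obs.columns:
        if isinstance(obs[column].dtype, pd.CategoricalDtype):
            obs[column] = obs[column].cat.remove_unused_categories()


# A subset that still carries the unused category "b".
obs = pd.DataFrame(
    {
        "category": pd.Categorical(["a", "c"], categories=["a", "b", "c"]),
        "score": [0.1, 0.2],
    }
)
remove_unused_categories_inplace(obs)
print(list(obs["category"].cat.categories))  # → ['a', 'c']
```

With the reproduction above, it would be called as `remove_unused_categories_inplace(sdata_blobs["other_table"].obs)` after `bounding_box_query` and before `render_labels`. Non-categorical columns such as `score` are left untouched.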