Skip to content

Commit

Permalink
Enhancement: KM plot data depth and functionality (#853)
Browse files Browse the repository at this point in the history
* using current commit of tutorials notebook

* Enhance Kaplan-Meier plot with grid layout and survival probability table

* added toggle for table

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix return values in kaplan_meier and update documentation for cox_ph_forestplot

* Fix typo in display_table parameter and add tests for Kaplan-Meier plots

* Rename display_table parameter to display_survival_statistics in kaplan_meier function and update related tests

* Fix documentation typo in cox_ph_forestplot function

* Update grid specification in kaplan_meier function and rerendered test images

* Improve docstring

---------

Co-authored-by: Eljas Roellin <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lukas Heumos <[email protected]>
  • Loading branch information
4 people authored Jan 21, 2025
1 parent d710367 commit ae9eed7
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 10 deletions.
60 changes: 50 additions & 10 deletions ehrapy/plot/_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import warnings
from typing import TYPE_CHECKING

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
from matplotlib import gridspec
from numpy import ndarray

from ehrapy.plot import scatter
Expand Down Expand Up @@ -185,6 +185,8 @@ def kmf(

def kaplan_meier(
kmfs: Sequence[KaplanMeierFitter],
*,
display_survival_statistics: bool = False,
ci_alpha: list[float] | None = None,
ci_force_lines: list[Boolean] | None = None,
ci_show: list[Boolean] | None = None,
Expand All @@ -206,6 +208,7 @@ def kaplan_meier(
Args:
kmfs: Iterables of fitted KaplanMeierFitter objects.
display_survival_statistics: Whether to show survival statistics in a table below the plot.
ci_alpha: The transparency level of the confidence interval. If more than one kmfs, this should be a list.
ci_force_lines: Force the confidence intervals to be line plots (versus default shaded areas).
If more than one kmfs, this should be a list.
Expand Down Expand Up @@ -264,7 +267,10 @@ def kaplan_meier(
at_risk_counts = [False] * len(kmfs)
if color is None:
color = [None] * len(kmfs)
plt.figure(figsize=figsize)

fig = plt.figure(constrained_layout=True, figsize=figsize)
spec = fig.add_gridspec(2, 1) if display_survival_statistics else fig.add_gridspec(1, 1)
ax = plt.subplot(spec[0, 0])

for i, kmf in enumerate(kmfs):
if i == 0:
Expand All @@ -286,16 +292,50 @@ def kaplan_meier(
at_risk_counts=at_risk_counts[i],
color=color[i],
)
# Configure plot appearance
ax.grid(grid)
plt.xlim(xlim)
plt.ylim(ylim)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
if title:
plt.title(title)
ax.set_title(title)

if display_survival_statistics:
xticks = [x for x in ax.get_xticks() if x >= 0]
xticks_space = xticks[1] - xticks[0]
if xlabel is None:
xlabel = "Time"

yticks = np.arange(len(kmfs))

ax_table = plt.subplot(spec[1, 0])
ax_table.set_xticks(xticks)
ax_table.set_xlim(-xticks_space / 2, xticks[-1] + xticks_space / 2)
ax_table.set_ylim(-1, len(kmfs))
ax_table.set_yticks(yticks)
ax_table.set_yticklabels([kmf.label if kmf.label else f"Group {i + 1}" for i, kmf in enumerate(kmfs[::-1])])

for i, kmf in enumerate(kmfs[::-1]):
survival_probs = kmf.survival_function_at_times(xticks).values
for j, prob in enumerate(survival_probs):
ax_table.text(
xticks[j], # x position
yticks[i], # y position
f"{prob:.2f}", # formatted survival probability
ha="center",
va="center",
bbox={"boxstyle": "round,pad=0.2", "edgecolor": "none", "facecolor": "lightgrey"},
)

ax_table.grid(grid)
ax_table.spines["top"].set_visible(False)
ax_table.spines["right"].set_visible(False)
ax_table.spines["bottom"].set_visible(False)
ax_table.spines["left"].set_visible(False)

if not show:
return ax
return fig, ax

else:
return None
Expand All @@ -320,13 +360,13 @@ def cox_ph_forestplot(
"""Generates a forest plot to visualize the coefficients and confidence intervals of a Cox Proportional Hazards model.
The `adata` object must first be populated using the :func:`~ehrapy.tools.cox_ph` function. This function stores the summary table of the `CoxPHFitter` in the `.uns` attribute of `adata`.
The summary table is created when the model is fitted using the :func:`ehrapy.tl.cox_ph` function.
The summary table is created when the model is fitted using the :func:`~ehrapy.tools.cox_ph` function.
For more information on the `CoxPHFitter`, see the `Lifelines documentation <https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html>`_.
Inspired by `zepid.graphics.EffectMeasurePlot <https://readthedocs.org>`_ (zEpid Package, https://pypi.org/project/zepid/).
Args:
adata: :class:`~anndata.AnnData` object containing the summary table from the CoxPHFitter. This is stored in the `.uns` attribute, after fitting the model using :func:`~ehrapy.tl.cox_ph`.
adata: :class:`~anndata.AnnData` object containing the summary table from the CoxPHFitter. This is stored in the `.uns` attribute, after fitting the model using :func:`~ehrapy.tools.cox_ph`.
uns_key: Key in `.uns` where :func:`~ehrapy.tools.cox_ph` function stored the summary table. See argument `uns_key` in :func:`~ehrapy.tools.cox_ph`.
labels: List of labels for each coefficient, default uses the index of the summary table.
fig_size: Width, height in inches.
Expand Down
118 changes: 118 additions & 0 deletions tests/_scripts/kaplain_meier_create_expected_plots.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"import ehrapy as ep\n",
"\n",
"current_notebook_dir = %pwd\n",
"_TEST_IMAGE_PATH = f\"{current_notebook_dir}/../plot/_images\"\n",
"mimic_2 = ep.dt.mimic_2(encoded=False)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"mimic_2[:, [\"censor_flg\"]].X = np.where(mimic_2[:, [\"censor_flg\"]].X == 0, 1, 0)\n",
"groups = mimic_2[:, [\"service_unit\"]].X\n",
"adata_ficu = mimic_2[groups == \"FICU\"]\n",
"adata_micu = mimic_2[groups == \"MICU\"]\n",
"kmf_1 = ep.tl.kaplan_meier(adata_ficu, duration_col=\"mort_day_censored\", event_col=\"censor_flg\", label=\"FICU\")\n",
"kmf_2 = ep.tl.kaplan_meier(adata_micu, duration_col=\"mort_day_censored\", event_col=\"censor_flg\", label=\"MICU\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = ep.pl.kaplan_meier(\n",
" [kmf_1, kmf_2],\n",
" ci_show=[False, False, False],\n",
" color=[\"k\", \"r\"],\n",
" xlim=[0, 750],\n",
" ylim=[0, 1],\n",
" xlabel=\"Days\",\n",
" ylabel=\"Proportion Survived\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"fig.savefig(f\"{_TEST_IMAGE_PATH}/kaplan_meier_expected.png\", dpi=80)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = ep.pl.kaplan_meier(\n",
" [kmf_1, kmf_2],\n",
" ci_show=[False, False, False],\n",
" color=[\"k\", \"r\"],\n",
" xlim=[0, 750],\n",
" ylim=[0, 1],\n",
" xlabel=\"Days\",\n",
" ylabel=\"Proportion Survived\",\n",
" display_survival_statistics=True,\n",
" grid=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"fig.savefig(f\"{_TEST_IMAGE_PATH}/kaplan_meier_table_expected.png\", dpi=80)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Binary file added tests/plot/_images/kaplan_meier_expected.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
44 changes: 44 additions & 0 deletions tests/plot/test_survival_analysis.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,55 @@
from pathlib import Path

import numpy as np

import ehrapy as ep

# Directory that contains this test module.
CURRENT_DIR = Path(__file__).parent
# The tests below build their image ``base_path`` values from this ``_images`` subfolder.
_TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images"


def test_kaplan_meier(mimic_2, check_same_image):
    """Render Kaplan-Meier plots for two service units and compare them to stored images.

    Two figures are checked: the plain survival plot, and the same plot with
    ``grid=True`` and ``display_survival_statistics=True``.

    Args:
        mimic_2: Fixture providing the MIMIC-II dataset (presumably an AnnData object; confirm in conftest).
        check_same_image: Fixture that compares ``fig`` with the image identified by ``base_path`` within ``tol``.
    """
    # Swap 0 and 1 in ``censor_flg`` before it is used as ``event_col`` below.
    mimic_2[:, ["censor_flg"]].X = np.where(mimic_2[:, ["censor_flg"]].X == 0, 1, 0)
    groups = mimic_2[:, ["service_unit"]].X
    adata_ficu = mimic_2[groups == "FICU"]
    adata_micu = mimic_2[groups == "MICU"]
    kmf_1 = ep.tl.kaplan_meier(adata_ficu, duration_col="mort_day_censored", event_col="censor_flg", label="FICU")
    kmf_2 = ep.tl.kaplan_meier(adata_micu, duration_col="mort_day_censored", event_col="censor_flg", label="MICU")
    kmfs = [kmf_1, kmf_2]

    # Keyword arguments shared by both plots. ``ci_show`` holds one entry per
    # fitter, so its length is derived from ``kmfs`` instead of being hard-coded.
    common_kwargs = {
        "ci_show": [False] * len(kmfs),
        "color": ["k", "r"],
        "xlim": [0, 750],
        "ylim": [0, 1],
        "xlabel": "Days",
        "ylabel": "Proportion Survived",
    }

    # Plain survival plot.
    fig, _ = ep.pl.kaplan_meier(kmfs, **common_kwargs)
    check_same_image(
        fig=fig,
        base_path=f"{_TEST_IMAGE_PATH}/kaplan_meier",
        tol=2e-1,
    )

    # Survival plot with grid lines and the survival-probability table underneath.
    fig, _ = ep.pl.kaplan_meier(
        kmfs,
        grid=True,
        display_survival_statistics=True,
        **common_kwargs,
    )
    check_same_image(
        fig=fig,
        base_path=f"{_TEST_IMAGE_PATH}/kaplan_meier_table",
        tol=2e-1,
    )


def test_coxph_forestplot(mimic_2, check_same_image):
adata_subset = mimic_2[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
ep.tl.cox_ph(adata_subset, duration_col="mort_day_censored", event_col="censor_flg")
Expand Down

0 comments on commit ae9eed7

Please sign in to comment.