Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve survival analysis interface #825

Merged
merged 31 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
a970b5b
updated kmf to match method signature
aGuyLearning Nov 13, 2024
7434bde
updated notebook
aGuyLearning Nov 13, 2024
5add2c3
updated ehrapy tutorial commit
aGuyLearning Nov 13, 2024
150a7f7
updated docu for new method signature
aGuyLearning Nov 13, 2024
b66fb44
added outputs to survival analysis
aGuyLearning Nov 13, 2024
0c8e6d6
correctly passing on fitting options
aGuyLearning Nov 13, 2024
95c2b74
pull request fixes.
aGuyLearning Nov 20, 2024
6085e96
added legacy support
aGuyLearning Nov 20, 2024
16b7d5f
added kmf function legacy support in tests and added new kaplan_meier…
aGuyLearning Nov 27, 2024
579c220
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
82a6e3c
updated notebook
aGuyLearning Nov 27, 2024
c604074
added stacklevel to deprecation warning
aGuyLearning Nov 27, 2024
f6b5a89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
7322f15
added deprecation warning in comment
aGuyLearning Nov 27, 2024
75f8000
Merge branch 'main' into enhancement/issue-822
Zethson Nov 28, 2024
1b80ff1
Update ehrapy/plot/_survival_analysis.py
eroell Dec 1, 2024
a26f6bc
Update ehrapy/plot/_survival_analysis.py
eroell Dec 1, 2024
1442983
Update ehrapy/plot/_survival_analysis.py
eroell Dec 1, 2024
972b71b
Update ehrapy/plot/_survival_analysis.py
eroell Dec 1, 2024
915df91
Update tests/tools/test_sa.py
eroell Dec 1, 2024
a3502b5
doc adjustments
eroell Dec 1, 2024
969eeb9
Merge branch 'main' into enhancement/issue-822
eroell Dec 1, 2024
1940ace
change name of kmf plot to kaplan_meier, some adjustments
eroell Dec 1, 2024
18f9292
introduce keyword only for univariate sa
eroell Dec 1, 2024
8c14039
correct docstring
eroell Dec 1, 2024
6f291bc
update submodule
eroell Dec 1, 2024
08e5949
add lifelines intersphinx mappings
eroell Dec 1, 2024
315e564
Update ehrapy/tools/_sa.py
Zethson Dec 2, 2024
b9e5bfb
Update ehrapy/tools/_sa.py
Zethson Dec 2, 2024
7bbe627
Update ehrapy/tools/_sa.py
Zethson Dec 2, 2024
5e14096
Merge branch 'main' into enhancement/issue-822
Zethson Dec 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
16 changes: 7 additions & 9 deletions ehrapy/plot/_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,22 +186,20 @@ def kmf(
# So we need to flip `censor_flg` before passing it to KaplanMeierFitter

>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
>>> kmf = ep.tl.kmf(adata, "mort_day_censored", "censor_flg")
eroell marked this conversation as resolved.
Show resolved Hide resolved
>>> ep.pl.kmf(
... [kmf], color=["r"], xlim=[0, 700], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived", show=True
... )

.. image:: /_static/docstring_previews/kmf_plot_1.png

>>> T = adata[:, ["mort_day_censored"]].X
>>> E = adata[:, ["censor_flg"]].X
>>> groups = adata[:, ["service_unit"]].X
>>> ix1 = groups == "FICU"
>>> ix2 = groups == "MICU"
>>> ix3 = groups == "SICU"
>>> kmf_1 = ep.tl.kmf(T[ix1], E[ix1], label="FICU")
>>> kmf_2 = ep.tl.kmf(T[ix2], E[ix2], label="MICU")
>>> kmf_3 = ep.tl.kmf(T[ix3], E[ix3], label="SICU")
>>> adata_ficu = adata[groups == "FICU"]
>>> adata_micu = adata[groups == "MICU"]
>>> adata_sicu = adata[groups == "SICU"]
>>> kmf_1 = ep.tl.kmf(adata_ficu, "mort_day_censored", "censor_flg", label="FICU")
eroell marked this conversation as resolved.
Show resolved Hide resolved
>>> kmf_2 = ep.tl.kmf(adata_micu, "mort_day_censored", "censor_flg", label="MICU")
eroell marked this conversation as resolved.
Show resolved Hide resolved
>>> kmf_3 = ep.tl.kmf(adata_sicu, "mort_day_censored", "censor_flg", label="SICU")
eroell marked this conversation as resolved.
Show resolved Hide resolved
>>> ep.pl.kmf([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'],
...     xlim=[0, 750], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived")

Expand Down
78 changes: 54 additions & 24 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,10 @@ def glm(


def kmf(
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
Zethson marked this conversation as resolved.
Show resolved Hide resolved
durations: Iterable,
event_observed: Iterable | None = None,
timeline: Iterable = None,
entry: Iterable | None = None,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
label: str | None = None,
alpha: float | None = None,
ci_labels: tuple[str, str] = None,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
weights: Iterable | None = None,
censoring: Literal["right", "left"] = None,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
adata: AnnData,
duration_col: str,
event_col: str | None = None,
**kwargs,
) -> KaplanMeierFitter:
"""Fit the Kaplan-Meier estimate for the survival function.

Expand Down Expand Up @@ -156,24 +151,38 @@ def kmf(
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # Flip 'censor_flg' because 0 = death and 1 = censored
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
>>> kmf = ep.tl.kmf(adata, "mort_day_censored", "censor_flg", label="Mortality")
"""

kmf = KaplanMeierFitter()
if censoring == "None" or "right":
kmf.fit(
durations=durations,
event_observed=event_observed,
df = anndata_to_df(adata)
T = df[duration_col]
E = df[event_col]

# unpack kwargs
timeline = kwargs.get("timeline", None)
entry = kwargs.get("entry", None)
label = kwargs.get("label", None)
alpha = kwargs.get("alpha", None)
ci_labels = kwargs.get("ci_labels", None)
weights = kwargs.get("weights", None)
censoring = kwargs.get("censoring", "right")

if censoring == "left":
kmf.fit_left_censoring(
durations=T,
event_observed=E,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
)
elif censoring == "left":
kmf.fit_left_censoring(
durations=durations,
event_observed=event_observed,
else:
kmf.fit(
durations=T,
event_observed=E,
timeline=timeline,
entry=entry,
label=label,
Expand Down Expand Up @@ -376,7 +385,9 @@ def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_co
)


def _univariate_model(adata: AnnData, duration_col: str, event_col: str, model_class, accept_zero_duration=True):
def _univariate_model(
adata: AnnData, duration_col: str, event_col: str, model_class, accept_zero_duration=True, **kwargs
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
):
"""Convenience function for univariate models."""
df = anndata_to_df(adata)

Expand All @@ -385,13 +396,32 @@ def _univariate_model(adata: AnnData, duration_col: str, event_col: str, model_c
T = df[duration_col]
E = df[event_col]

# unpack kwargs
timeline = kwargs.get("timeline", None)
entry = kwargs.get("entry", None)
label = kwargs.get("label", None)
alpha = kwargs.get("alpha", None)
ci_labels = kwargs.get("ci_labels", None)
weights = kwargs.get("weights", None)
fit_options = kwargs.get("fit_options", None)

model = model_class()
model.fit(T, event_observed=E)
model.fit(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a `fit_left_censoring` call for when left censoring is requested via the arguments.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check whether Nelson-Aalen fails gracefully with a meaningful error message.

T,
event_observed=E,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
fit_options=fit_options,
)

return model


def nelson_aalen(adata: AnnData, duration_col: str, event_col: str) -> NelsonAalenFitter:
def nelson_aalen(adata: AnnData, duration_col: str, event_col: str, **kwargs) -> NelsonAalenFitter:
"""Employ the Nelson-Aalen estimator to estimate the cumulative hazard function from censored survival data

The Nelson-Aalen estimator is a non-parametric method used in survival analysis to estimate the cumulative hazard function.
Expand All @@ -415,10 +445,10 @@ def nelson_aalen(adata: AnnData, duration_col: str, event_col: str) -> NelsonAal
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> naf = ep.tl.nelson_aalen(adata, "mort_day_censored", "censor_flg")
"""
return _univariate_model(adata, duration_col, event_col, NelsonAalenFitter)
return _univariate_model(adata, duration_col, event_col, NelsonAalenFitter, True, **kwargs)


def weibull(adata: AnnData, duration_col: str, event_col: str) -> WeibullFitter:
def weibull(adata: AnnData, duration_col: str, event_col: str, **kwargs) -> WeibullFitter:
"""Employ the Weibull model in univariate survival analysis to understand event occurrence dynamics.

In contrast to the non-parametric Nelson-Aalen estimator, the Weibull model employs a parametric approach with shape and scale parameters,
Expand All @@ -445,4 +475,4 @@ def weibull(adata: AnnData, duration_col: str, event_col: str) -> WeibullFitter:
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> wf = ep.tl.weibull(adata, "mort_day_censored", "censor_flg")
"""
return _univariate_model(adata, duration_col, event_col, WeibullFitter, accept_zero_duration=False)
return _univariate_model(adata, duration_col, event_col, WeibullFitter, accept_zero_duration=False, **kwargs)
2 changes: 1 addition & 1 deletion tests/tools/test_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _sa_func_test(self, sa_function, sa_class, mimic_2_sa):

def test_kmf(self, mimic_2_sa):
adata, _, _ = mimic_2_sa
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
kmf = ep.tl.kmf(adata, "mort_day_censored", "censor_flg")
self._sa_function_assert(kmf, KaplanMeierFitter)

def test_cox_ph(self, mimic_2_sa):
Expand Down
Loading