Skip to content

Commit

Permalink
Coxphfitter (#643)
Browse files Browse the repository at this point in the history
* add lifeline CoxPHFitter to survival analysis

* add CoxPHFitter survival analysis test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add cox_ph to docs

* import cox_ph function

* add docstring + rename to cox_ph

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* correct ehrapy_anndata import

* update cox_ph and kmf docstring

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* debug _sa

* debug entry_col for cox_ph

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix formatting of some survival analysis

Signed-off-by: zethson <[email protected]>

---------

Signed-off-by: zethson <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: zethson <[email protected]>
  • Loading branch information
3 people authored Jan 23, 2024
1 parent 8fe3f2d commit 68d6498
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret
tools.kmf
tools.test_kmf_logrank
tools.test_nested_f_statistic
tools.cox_ph
```

### Causal Inference
Expand Down
2 changes: 1 addition & 1 deletion ehrapy/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ehrapy.tools._sa import anova_glm, glm, kmf, ols, test_kmf_logrank, test_nested_f_statistic
from ehrapy.tools._sa import anova_glm, cox_ph, glm, kmf, ols, test_kmf_logrank, test_nested_f_statistic
from ehrapy.tools._scanpy_tl_api import * # noqa: F403
from ehrapy.tools.causal._dowhy import causal_inference
from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups
Expand Down
68 changes: 53 additions & 15 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from lifelines import KaplanMeierFitter
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.statistics import StatisticalResult, logrank_test
from scipy import stats

from ehrapy.anndata import anndata_to_df

if TYPE_CHECKING:
from collections.abc import Iterable

Expand All @@ -26,13 +28,14 @@ def ols(
"""Create a Ordinary Least Squares (OLS) Model from a formula and AnnData.
See https://www.statsmodels.org/stable/generated/statsmodels.formula.api.ols.html#statsmodels.formula.api.ols
Internally use the statsmodel to create a OLS Model from a formula and dataframe.
Args:
adata: The AnnData object for the OLS model.
var_names: A list of var names indicating which columns are for the OLS model.
formula: The formula specifying the model.
missing: Available options are 'none', 'drop', and 'raise'. If 'none', no nan checking is done. If 'drop', any observations with nans are dropped. If 'raise', an error is raised. Default is 'none'.
missing: Available options are 'none', 'drop', and 'raise'.
If 'none', no nan checking is done. If 'drop', any observations with nans are dropped.
If 'raise', an error is raised. Defaults to 'none'.
Returns:
The OLS model instance.
Expand Down Expand Up @@ -64,7 +67,6 @@ def glm(
"""Create a Generalized Linear Model (GLM) from a formula, a distribution, and AnnData.
See https://www.statsmodels.org/stable/generated/statsmodels.formula.api.glm.html#statsmodels.formula.api.glm
Internally use the statsmodel to create a GLM Model from a formula, a distribution, and dataframe.
Args:
adata: The AnnData object for the GLM model.
Expand All @@ -74,7 +76,7 @@ def glm(
Defaults to 'Gaussian'.
missing: Available options are 'none', 'drop', and 'raise'. If 'none', no nan checking is done.
If 'drop', any observations with nans are dropped. If 'raise', an error is raised (default: 'none').
ascontinus: A list of var names indicating which columns are continuous rather than categorical.
as_continuous: A list of var names indicating which columns are continuous rather than categorical.
The corresponding columns will be set as type float.
Returns:
Expand All @@ -86,7 +88,7 @@ def glm(
>>> formula = 'day_28_flg ~ age'
>>> var_names = ['day_28_flg', 'age']
>>> family = 'Binomial'
>>> glm = ep.tl.glmglm(adata, var_names, formula, family, missing = 'drop', ascontinus = ['age'])
>>> glm = ep.tl.glm(adata, var_names, formula, family, missing = 'drop', ascontinus = ['age'])
"""
family_dict = {
"Gaussian": sm.families.Gaussian(),
Expand Down Expand Up @@ -120,15 +122,18 @@ def kmf(
) -> KaplanMeierFitter:
"""Fit the Kaplan-Meier estimate for the survival function.
See https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter
Class for fitting the Kaplan-Meier estimate for the survival function.
The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic used to estimate the survival function from lifetime data.
In medical research, it is often used to measure the fraction of patients living for a certain amount of time after treatment.
See https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator
https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter
Args:
durations: length n -- duration (relative to subject's birth) the subject was alive for.
event_observed: True if the the death was observed, False if the event was lost (right-censored). Defaults to all True if event_observed==None.
event_observed: True if the death was observed, False if the event was lost (right-censored). Defaults to all True if event_observed==None.
timeline: return the best estimate at the values in timelines (positively increasing)
entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations.
If None, all members of the population entered study when they were "born".
If None, all members of the population entered study when they were "born".
label: A string to name the column of the estimate.
alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only.
ci_labels: Add custom column names to the generated confidence intervals as a length-2 list: [<lower-bound name>, <upper-bound name>] (default: <label>_lower_<1-alpha/2>).
Expand All @@ -143,9 +148,7 @@ def kmf(
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # Because in MIMIC-II database, `censor_fl` is censored or death (binary: 0 = death, 1 = censored).
>>> # While in KaplanMeierFitter, `event_observed` is True if the the death was observed, False if the event was lost (right-censored).
>>> # So we need to flip `censor_fl` when pass `censor_fl` to KaplanMeierFitter
>>> # Flip 'censor_fl' because 0 = death and 1 = censored
>>> adata[:, ['censor_flg']].X = np.where(adata[:, ['censor_flg']].X == 0, 1, 0)
>>> kmf = ep.tl.kmf(adata[:, ['mort_day_censored']].X, adata[:, ['censor_flg']].X)
"""
Expand Down Expand Up @@ -184,12 +187,12 @@ def test_kmf_logrank(
) -> StatisticalResult:
"""Calculates the p-value for the logrank test comparing the survival functions of two groups.
See https://lifelines.readthedocs.io/en/latest/lifelines.statistics.html
Measures and reports on whether two intensity processes are different.
That is, given two event series, determines whether the data generating processes are statistically different.
The test-statistic is chi-squared under the null hypothesis.
See https://lifelines.readthedocs.io/en/latest/lifelines.statistics.html
Args:
kmf_A: The first KaplanMeierFitter object containing the durations and events.
kmf_B: The second KaplanMeierFitter object containing the durations and events.
Expand Down Expand Up @@ -262,3 +265,38 @@ def anova_glm(result_1: GLMResultsWrapper, result_2: GLMResultsWrapper, formula_
}
dataframe = pd.DataFrame(data=table)
return dataframe


def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> CoxPHFitter:
"""Fit the Cox’s proportional hazard for the survival function.
The Cox proportional hazards model (CoxPH) examines the relationship between the survival time of subjects and one or more predictor variables.
It models the hazard rate as a product of a baseline hazard function and an exponential function of the predictors, assuming proportional hazards over time.
See https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html
Args:
adata: adata: AnnData object with necessary columns `duration_col` and `event_col`.
duration_col: the name of the column in the AnnData objects that contains the subjects’ lifetimes.
event_col: the name of the column in anndata that contains the subjects’ death observation. If left as None, assume all individuals are uncensored.
entry_col: a column denoting when a subject entered the study, i.e. left-truncation.
Returns:
Fitted CoxPHFitter
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # Flip 'censor_fl' because 0 = death and 1 = censored
>>> adata[:, ['censor_flg']].X = np.where(adata[:, ['censor_flg']].X == 0, 1, 0)
>>> cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg")
"""
df = anndata_to_df(adata)
keys = [duration_col, event_col]
if entry_col:
keys.append(entry_col)
df = df[keys]
cph = CoxPHFitter()
cph.fit(df, duration_col, event_col, entry_col=entry_col)

return cph
11 changes: 10 additions & 1 deletion tests/tools/test_sa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
import statsmodels
from lifelines import KaplanMeierFitter
from lifelines import CoxPHFitter, KaplanMeierFitter

import ehrapy as ep

Expand Down Expand Up @@ -75,3 +75,12 @@ def test_anova_glm(self):
assert dataframe.shape == (2, 6)
assert dataframe.iloc[1, 4] == 2
assert pytest.approx(dataframe.iloc[1, 5], 0.1) == 0.103185

def test_cox_ph(self):
adata = ep.dt.mimic_2(encoded=False)
adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg")

assert isinstance(cph, CoxPHFitter)
assert len(cph.durations) == 1776
assert sum(cph.event_observed) == 497

0 comments on commit 68d6498

Please sign in to comment.