Skip to content

Commit

Permalink
small comments included
Browse files Browse the repository at this point in the history
  • Loading branch information
eroell committed Feb 29, 2024
1 parent 3673d28 commit 9687e4d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
43 changes: 23 additions & 20 deletions ehrapy/tools/cohort_tracking/_cohort_tracker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from collections.abc import Iterable
from typing import Any, Union

import graphviz
Expand All @@ -10,10 +11,10 @@
from tableone import TableOne


def _check_columns_exist(df, columns):
if not all(col in df.columns for col in columns):
missing_columns = [col for col in columns if col not in df.columns]
raise ValueError(f"Columns {missing_columns} not found in dataframe.")
def _check_columns_exist(df, columns) -> None:
missing_columns = set(columns) - set(df.columns)
if missing_columns:
raise ValueError(f"Columns {list(missing_columns)} not found in dataframe.")


# from tableone: https://github.com/tompollard/tableone/blob/bfd6fbaa4ed3e9f59e1a75191c6296a2a80ccc64/tableone/tableone.py#L555
Expand All @@ -34,7 +35,9 @@ def _detect_categorical_columns(data) -> list:


class CohortTracker:
def __init__(self, adata: AnnData | pd.DataFrame, columns: list = None, categorical: list = None, *args: Any):
def __init__(
self, adata: AnnData | pd.DataFrame, columns: Iterable = None, categorical: Iterable = None, *args: Any
):
"""Track cohort changes over multiple filtering or processing steps.
This class offers functionality to track and plot cohort changes over multiple filtering or processing steps,
Expand All @@ -43,8 +46,8 @@ def __init__(self, adata: AnnData | pd.DataFrame, columns: list = None, categori
Tightly interacting with the `tableone` package [1].
Args:
adata: :class:`~anndata.AnnData` or :class:`~pandas.DataFrame` object to track.
columns: List of columns to track. If `None`, all columns will be tracked.
categorical: List of columns that contain categorical variables, if not given will be inferred from the data.
columns: Iterable of columns to track. If `None`, all columns will be tracked.
categorical: Iterable of columns that contain categorical variables, if not given will be inferred from the data.
References
----------
Expand Down Expand Up @@ -80,7 +83,7 @@ def __init__(self, adata: AnnData | pd.DataFrame, columns: list = None, categori

def __call__(
self, adata: AnnData, label: str = None, operations_done: str = None, *args: Any, **tableone_kwargs: Any
) -> Any:
) -> None:
if isinstance(adata, AnnData):
df = adata.obs
elif isinstance(adata, pd.DataFrame):
Expand Down Expand Up @@ -151,9 +154,9 @@ def plot_cohort_change(
self,
set_axis_labels=True,
subfigure_title: bool = False,
sns_color_palette: str = "husl",
color_palette: str = "husl",
save: str = None,
return_plot: bool = False,
return_figure: bool = False,
subplots_kwargs: dict = None,
legend_kwargs: dict = None,
):
Expand All @@ -164,16 +167,16 @@ def plot_cohort_change(
Args:
set_axis_labels: If `True`, the y-axis labels will be set to the column names.
subfigure_title: If `True`, each subplot will have a title with the `label` provided during tracking.
sns_color_palette: The color palette to use for the plot. Default is "husl".
color_palette: The color palette to use for the plot. Default is "husl".
save: If a string is provided, the plot will be saved to the path specified.
return_plot: If `True`, the plot will be returned as a tuple of (fig, ax).
return_figure: If `True`, the plot will be returned as a tuple of (fig, ax).
subplot_kwargs: Additional keyword arguments for the subplots.
legend_kwargs: Additional keyword arguments for the legend.
Returns:
If `return_plot` a :class:`~matplotlib.figure.Figure` and a :class:`~matplotlib.axes.Axes` or a list of it.
If `return_figure` a :class:`~matplotlib.figure.Figure` and a :class:`~matplotlib.axes.Axes` or a list of it.
Example:
Examples:
.. code-block:: python
import ehrapy as ep
Expand Down Expand Up @@ -211,7 +214,7 @@ def plot_cohort_change(

# Adjust the hue shift based on the category position such that the colors are more distinguishable
hue_shift = (pos + 1) / len(data)
colors = sns.color_palette(sns_color_palette, len(data))
colors = sns.color_palette(color_palette, len(data))
adjusted_colors = [((color[0] + hue_shift) % 1, color[1], color[2]) for color in colors]

# for categoricals, plot multiple bars
Expand Down Expand Up @@ -276,24 +279,24 @@ def plot_cohort_change(
save,
)

if return_plot:
if return_figure:
return fig, axes

else:
plt.tight_layout()
plt.show()

def plot_flowchart(self, save: str = None, return_plot: bool = True):
def plot_flowchart(self, save: str = None, return_figure: bool = True):
"""Flowchart over the tracked steps.
Create a simple flowchart of data preparation steps tracked with `CohortTracker`.
Args:
save: If a string is provided, the plot will be saved to the path specified.
return_plot: If `True`, the plot will be returned as a :class:`~graphviz.Digraph`.
return_figure: If `True`, the plot will be returned as a :class:`~graphviz.Digraph`.
Returns:
If `return_plot` a :class:`~graphviz.Digraph`.
If `return_figure` a :class:`~graphviz.Digraph`.
Example:
.. code-block:: python
Expand Down Expand Up @@ -328,5 +331,5 @@ def plot_flowchart(self, save: str = None, return_plot: bool = True):
dot.render(save, format="png", cleanup=True)

# Think that to be shown, the plot can a) be rendered (as above) or be "printed" by the notebook
if return_plot:
if return_figure:
return dot
3 changes: 1 addition & 2 deletions tests/tools/cohort_tracking/test_cohort_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pytest

import ehrapy as ep
import ehrapy.tools.feature_ranking._rank_features_groups as _utils
from ehrapy.io._read import read_csv

CURRENT_DIR = Path(__file__).parent
Expand Down Expand Up @@ -166,4 +165,4 @@ def test_CohortTracker_plot_cohort_change(self):
ct(adata)
ct(adata)

ct.plot_cohort_change(return_plot=True)
ct.plot_cohort_change(return_figure=True)

0 comments on commit 9687e4d

Please sign in to comment.