From 204a3a7bd64c9b75c73d3313daff673bcaa72663 Mon Sep 17 00:00:00 2001 From: Lukas Heumos Date: Mon, 18 Dec 2023 17:41:38 +0100 Subject: [PATCH] Refactor (#629) Signed-off-by: zethson --- ehrapy/plot/__init__.py | 1 - ehrapy/plot/_qc.py | 58 ---- ehrapy/preprocessing/_encode.py | 5 +- ehrapy/preprocessing/_quality_control.py | 44 +-- ehrapy/tools/__init__.py | 2 +- ehrapy/tools/_scanpy_tl_api.py | 49 --- .../feature_ranking/_rank_features_groups.py | 59 +++- tests/preprocessing/test_imputation.py | 303 ++++++++---------- ...ents.py => test_summarize_measurements.py} | 0 9 files changed, 206 insertions(+), 315 deletions(-) delete mode 100644 ehrapy/plot/_qc.py rename tests/preprocessing/{test_expand_measurements.py => test_summarize_measurements.py} (100%) diff --git a/ehrapy/plot/__init__.py b/ehrapy/plot/__init__.py index 807f20a0..102c2f26 100644 --- a/ehrapy/plot/__init__.py +++ b/ehrapy/plot/__init__.py @@ -1,5 +1,4 @@ from ehrapy.plot._missingno_pl_api import * # noqa: F403 -from ehrapy.plot._qc import qc_metrics from ehrapy.plot._scanpy_pl_api import * # noqa: F403 from ehrapy.plot._survival_analysis import kmf, ols from ehrapy.plot._util import * # noqa: F403 diff --git a/ehrapy/plot/_qc.py b/ehrapy/plot/_qc.py deleted file mode 100644 index fe6b741a..00000000 --- a/ehrapy/plot/_qc.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from rich.console import Console -from rich.table import Table - -if TYPE_CHECKING: - from anndata import AnnData - - -def qc_metrics(adata: AnnData, extra_columns: list[str] | None = None) -> None: # pragma: no cover - """Plots the calculated quality control metrics for var of adata. - - Per default this will display the following features: - ``missing_values_abs``, ``missing_values_pct``, ``mean``, ``median``, ``standard_deviation``, ``max``, ``min``. - - Args: - adata: Annotated data matrix. - extra_columns: List of custom (qc) var columns to be displayed additionally. - - Examples: - >>> import ehrapy as ep - >>> adata = ep.dt.mimic_2(encoded=True) - >>> ep.pp.qc_metrics(adata) - >>> ep.pl.qc_metrics(adata) - """ - table = Table(title="[bold blue]Ehrapy qc metrics of var") - # add special column header for the column name - table.add_column("[bold blue]Column name", justify="right", style="bold green") - var_names = list(adata.var_names) - # default qc columns added to var - fixed_qc_columns = [ - "missing_values_abs", - "missing_values_pct", - "mean", - "median", - "standard_deviation", - "min", - "max", - ] - # update columns to display with extra columns (if any) - columns_to_display = fixed_qc_columns if not extra_columns else fixed_qc_columns + extra_columns - # check whether all columns exist (qc has been executed before and extra columns are var columns) - if (set(columns_to_display) & set(adata.var.columns)) != set(columns_to_display): - raise AttributeError( - "Cannot display QC metrics of current AnnData object. Either QC has not been executed before or " - "some column(s) of the extra_columns parameter are not in var!" 
- ) - vars_to_display = adata.var[columns_to_display] - # add column headers - for col in vars_to_display: - table.add_column(f"[bold blue]{col}", justify="right", style="bold green") - for var in range(len(vars_to_display)): - table.add_row(var_names[var], *map(str, list(vars_to_display.iloc[var]))) - - console = Console() - console.print(table) diff --git a/ehrapy/preprocessing/_encode.py b/ehrapy/preprocessing/_encode.py index 3463edf3..7abb4380 100644 --- a/ehrapy/preprocessing/_encode.py +++ b/ehrapy/preprocessing/_encode.py @@ -114,8 +114,6 @@ def undo_encoding( else: raise ValueError(f"Cannot decode object of type {type(data)}. Can only decode AnnData objects!") - return None - def _encode( adata: AnnData, @@ -787,8 +785,7 @@ def _reorder_encodings(adata: AnnData, new_encodings: dict[str, list[list[str]] # if encoding mode is if not encoded_categoricals_with_mode: del adata.uns["encoding_to_var"][encode_mode] - logg.info("Re-encoded the AnnData object.") - # return updated encodings + return _update_new_encode_modes(new_encodings, adata.uns["encoding_to_var"]) diff --git a/ehrapy/preprocessing/_quality_control.py b/ehrapy/preprocessing/_quality_control.py index 7daac9cd..c9d6a32c 100644 --- a/ehrapy/preprocessing/_quality_control.py +++ b/ehrapy/preprocessing/_quality_control.py @@ -9,8 +9,6 @@ from rich import print from thefuzz import process -from ehrapy import logging as logg - if TYPE_CHECKING: from collections.abc import Collection @@ -36,36 +34,25 @@ def qc_metrics( Observation level metrics include: - `missing_values_abs` - Absolute amount of missing values. - `missing_values_pct` - Relative amount of missing values in percent. + - `missing_values_abs`: Absolute amount of missing values. + - `missing_values_pct`: Relative amount of missing values in percent. Feature level metrics include: - `missing_values_abs` - Absolute amount of missing values. - `missing_values_pct` - Relative amount of missing values in percent. - `mean` - Mean value of the features. - `median` - Median value of the features. - `std` - Standard deviation of the features. - `min` - Minimum value of the features. - `max` - Maximum value of the features. + - `missing_values_abs`: Absolute amount of missing values. + - `missing_values_pct`: Relative amount of missing values in percent. + - `mean`: Mean value of the features. + - `median`: Median value of the features. + - `std`: Standard deviation of the features. + - `min`: Minimum value of the features. + - `max`: Maximum value of the features. Examples: >>> import ehrapy as ep >>> import seaborn as sns - >>> import matplotlib.pyplot as plt >>> adata = ep.dt.mimic_2(encoded=True) >>> ep.pp.qc_metrics(adata) >>> sns.displot(adata.obs["missing_values_abs"]) - >>> plt.show() """ obs_metrics = _obs_qc_metrics(adata, layer, qc_vars) var_metrics = _var_qc_metrics(adata, layer) @@ -73,7 +60,6 @@ def qc_metrics( if inplace: adata.obs[obs_metrics.columns] = obs_metrics adata.var[var_metrics.columns] = var_metrics - logg.info("Added the calculated metrics to AnnData's `obs` and `var`.") return obs_metrics, var_metrics @@ -91,10 +77,8 @@ def _missing_values( Returns: Absolute or relative amount of missing values. """ - # Absolute number of missing values if shape is None: return pd.isnull(arr).sum() - # Relative number of missing values in percent else: n_rows, n_cols = shape if df_type == "obs": @@ -256,7 +240,7 @@ def qc_lab_measurements( If you want to specify your own table as a Pandas DataFrame please examine the existing default table. 
Ethnicity and age columns can be added.

-    https://github.com/theislab/ehrapy/ehrapy/preprocessing/laboratory_reference_tables/laposata.tsv
+    https://github.com/theislab/ehrapy/blob/main/ehrapy/preprocessing/laboratory_reference_tables/laposata.tsv

     Args:
         adata: Annotated data matrix.
         threshold: Minimum required matching confidence score of the fuzzysearch.
             0 = no matches, 100 = all must match.
             Defaults to 20.
         age_col: Column containing age values.
-        age_range: The inclusive age-range to filter for. e.g. 5-99
+        age_range: The inclusive age-range to filter for, such as 5-99.
         sex_col: Column containing sex values. Column must contain 'U', 'M' or 'F'.
         sex: Sex to filter the reference values for. Use U for unisex which uses male values when male and female conflict.
-            Defaults to 'U|M'
+            Defaults to 'U|M'.
         ethnicity_col: Column containing ethnicity values.
         ethnicity: Ethnicity to filter for.
-        copy: Whether to return a copy. Defaults to False .
+        copy: Whether to return a copy. Defaults to False.
         verbose: Whether to have verbose stdout. Notifies user of matched columns and value ranges.

     Returns:
             f"ethnicity columns and their values."
         )

-    # Fetch reference values
     try:
         if age_col:
             min_age, max_age = age_range.split("-")
     except TypeError:
         print(f"[bold yellow]Unable to find specified reference values for {measurement}.")

-    # Check whether the measurements are inside the reference ranges
     check = reference_values[reference_column].values
     check_str: str = np.array2string(check)
     check_str = check_str.replace("[", "").replace("]", "").replace("'", "")
diff --git a/ehrapy/tools/__init__.py b/ehrapy/tools/__init__.py
index 20f0d699..0468a346 100644
--- a/ehrapy/tools/__init__.py
+++ b/ehrapy/tools/__init__.py
@@ -1,7 +1,7 @@
 from ehrapy.tools._sa import anova_glm, glm, kmf, ols, test_kmf_logrank, test_nested_f_statistic
 from ehrapy.tools._scanpy_tl_api import *  # noqa: F403
 from ehrapy.tools.causal._dowhy import causal_inference
-from ehrapy.tools.feature_ranking._rank_features_groups import rank_features_groups
+from ehrapy.tools.feature_ranking._rank_features_groups import filter_rank_features_groups, rank_features_groups

 try:  # pragma: no cover
     from ehrapy.tools.nlp._medcat import (
diff --git a/ehrapy/tools/_scanpy_tl_api.py b/ehrapy/tools/_scanpy_tl_api.py
index 2780eb00..477768f1 100644
--- a/ehrapy/tools/_scanpy_tl_api.py
+++ b/ehrapy/tools/_scanpy_tl_api.py
@@ -679,52 +679,3 @@ def ingest(
         inplace=inplace,
         **kwargs,
     )
-
-
-def filter_rank_features_groups(
-    adata: AnnData,
-    key="rank_features_groups",
-    groupby=None,
-    key_added="rank_features_groups_filtered",
-    min_in_group_fraction=0.25,
-    min_fold_change=1,
-    max_out_group_fraction=0.5,
-) -> None:  # pragma: no cover
-    """Filters out features based on fold change and fraction of features containing the feature within and outside the `groupby` categories.
-
-    See :func:`~ehrapy.tl.rank_features_groups`.
-
-    Results are stored in `adata.uns[key_added]`
-    (default: 'rank_genes_groups_filtered').
-
-    To preserve the original structure of adata.uns['rank_genes_groups'],
-    filtered genes are set to `NaN`.
-
-    Args:
-        adata: Annotated data matrix.
-        key: Key previously added by :func:`~ehrapy.tl.rank_features_groups`
-        groupby: The key of the observations grouping to consider.
-        key_added: The key in `adata.uns` information is saved to.
-        min_in_group_fraction: Minimum in group fraction (default: 0.25).
-        min_fold_change: Miniumum fold change (default: 1).
-        max_out_group_fraction: Maximum out group fraction (default: 0.5).
-
-    Returns:
-        Same output as :func:`ehrapy.tl.rank_features_groups` but with filtered feature names set to `nan`
-
-    Examples:
-        >>> import ehrapy as ep
-        >>> adata = ep.dt.mimic_2(encoded=True)
-        >>> ep.tl.rank_features_groups(adata, "service_unit")
-        >>> ep.pl.rank_features_groups(adata)
-    """
-    return sc.tl.filter_rank_genes_groups(
-        adata=adata,
-        key=key,
-        groupby=groupby,
-        use_raw=False,
-        key_added=key_added,
-        min_in_group_fraction=min_in_group_fraction,
-        min_fold_change=min_fold_change,
-        max_out_group_fraction=max_out_group_fraction,
-    )
diff --git a/ehrapy/tools/feature_ranking/_rank_features_groups.py b/ehrapy/tools/feature_ranking/_rank_features_groups.py
index cb74f4c1..c35088ad 100644
--- a/ehrapy/tools/feature_ranking/_rank_features_groups.py
+++ b/ehrapy/tools/feature_ranking/_rank_features_groups.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Literal, Optional, Union
+from typing import TYPE_CHECKING, Literal

 import numpy as np
 import pandas as pd
             Used only for statistical tests (e.g. doesn't work for "logreg" `num_cols_method`)
         tie_correct: Use tie correction for `'wilcoxon'` scores. Used only for `'wilcoxon'`.
         layer: Key from `adata.layers` whose value will be used to perform tests on.
-        field_to_rank: Set to `layer` to rank variables in `adata.X` or `adata.layers[layer]` (default), `obs` to rank `adata.obs`, or `layer_and_obs` to rank both. Layer needs to be None if this is not 'layer'.
-        columns_to_rank: Subset of columns to rank. If 'all', all columns are used. If a dictionary, it must have keys 'var_names' and/or 'obs_names' and values must be iterables of strings. E.g. {'var_names': ['glucose'], 'obs_names': ['age', 'height']}.
+        field_to_rank: Set to `layer` to rank variables in `adata.X` or `adata.layers[layer]` (default), `obs` to rank `adata.obs`, or `layer_and_obs` to rank both.
+            Layer needs to be None if this is not 'layer'.
+        columns_to_rank: Subset of columns to rank. If 'all', all columns are used.
+            If a dictionary, it must have keys 'var_names' and/or 'obs_names' and values must be iterables of strings
+            such as {'var_names': ['glucose'], 'obs_names': ['age', 'height']}.
         **kwds: Are passed to test methods. Currently this affects only parameters that are passed to
             :class:`sklearn.linear_model.LogisticRegression`. For instance, you can pass `penalty='l1'` to try to come up with a
             adata_orig.uns[key_added] = adata.uns[key_added]
         adata = adata_orig

-    # Adjust p values
     if "pvals" in adata.uns[key_added]:
         adata.uns[key_added]["pvals_adj"] = _adjust_pvalues(
             adata.uns[key_added]["pvals"], corr_method=correction_method
     _sort_features(adata, key_added)

     return adata if copy else None
+
+
+def filter_rank_features_groups(
+    adata: AnnData,
+    key="rank_features_groups",
+    groupby=None,
+    key_added="rank_features_groups_filtered",
+    min_in_group_fraction=0.25,
+    min_fold_change=1,
+    max_out_group_fraction=0.5,
+) -> None:  # pragma: no cover
+    """Filters out features based on fold change and the fraction of observations containing the feature within and outside the `groupby` categories.
+
+    See :func:`~ehrapy.tl.rank_features_groups`.
+
+    Results are stored in `adata.uns[key_added]`
+    (default: 'rank_features_groups_filtered').
+
+    To preserve the original structure of adata.uns['rank_features_groups'],
+    filtered features are set to `NaN`.
+
+    Args:
+        adata: Annotated data matrix.
+        key: Key previously added by :func:`~ehrapy.tl.rank_features_groups`.
+        groupby: The key of the observations grouping to consider.
+        key_added: The key in `adata.uns` information is saved to.
+        min_in_group_fraction: Minimum in group fraction (default: 0.25).
+        min_fold_change: Minimum fold change (default: 1).
+        max_out_group_fraction: Maximum out group fraction (default: 0.5).
+
+    Returns:
+        Same output as :func:`ehrapy.tl.rank_features_groups` but with filtered feature names set to `NaN`.
+
+    Examples:
+        >>> import ehrapy as ep
+        >>> adata = ep.dt.mimic_2(encoded=True)
+        >>> ep.tl.rank_features_groups(adata, "service_unit")
+        >>> ep.pl.rank_features_groups(adata)
+    """
+    return sc.tl.filter_rank_genes_groups(
+        adata=adata,
+        key=key,
+        groupby=groupby,
+        use_raw=False,
+        key_added=key_added,
+        min_in_group_fraction=min_in_group_fraction,
+        min_fold_change=min_fold_change,
+        max_out_group_fraction=max_out_group_fraction,
+    )
diff --git a/tests/preprocessing/test_imputation.py b/tests/preprocessing/test_imputation.py
index 9c4baf93..ddf49dd4 100644
--- a/tests/preprocessing/test_imputation.py
+++ b/tests/preprocessing/test_imputation.py
@@ -24,363 +24,332 @@

 _TEST_PATH = f"{CURRENT_DIR}/test_data_imputation"


-def test_mean_impute_no_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    simple_impute(adata)
+@pytest.fixture
+def impute_num_adata():
+    return read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")

-    assert not np.isnan(adata.X).any()

+@pytest.fixture
+def impute_adata():
+    return read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")

-def test_mean_impute_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = simple_impute(adata, copy=True)

-    assert id(adata) != id(adata_imputed)
-    assert not np.isnan(adata_imputed.X).any()
+@pytest.fixture
+def impute_iris():
+    return read_csv(dataset_path=f"{_TEST_PATH}/test_impute_iris.csv")
+
+
+@pytest.fixture
+def impute_titanic():
+    return read_csv(dataset_path=f"{_TEST_PATH}/test_impute_titanic.csv")
+
+
+def test_mean_impute_no_copy(impute_num_adata):
+    simple_impute(impute_num_adata)
+    assert not np.isnan(impute_num_adata.X).any()


-def test_mean_impute_throws_error_non_numerical():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
+def test_mean_impute_copy(impute_num_adata):
+    adata_imputed = simple_impute(impute_num_adata, copy=True)
+
+    assert id(impute_num_adata) != id(adata_imputed)
+    assert not np.isnan(adata_imputed.X).any()
+
+
+def test_mean_impute_throws_error_non_numerical(impute_adata):
     with pytest.raises(ValueError):
-        simple_impute(adata)
+        simple_impute(impute_adata)


-def test_mean_impute_subset():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = simple_impute(adata, var_names=["intcol", "indexcol"], copy=True)
+def test_mean_impute_subset(impute_adata):
+    adata_imputed = simple_impute(impute_adata, var_names=["intcol", "indexcol"], copy=True)

     assert not np.all([item != item for item in adata_imputed.X[::, 1:2]])
     assert np.any([item != item for item in adata_imputed.X[::, 3:4]])


-def test_median_impute_no_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    simple_impute(adata, strategy="median")
+def test_median_impute_no_copy(impute_num_adata):
+    simple_impute(impute_num_adata, strategy="median")

-    assert not np.isnan(adata.X).any()
+    assert not np.isnan(impute_num_adata.X).any()


-def test_median_impute_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = simple_impute(adata, strategy="median", copy=True)
+def test_median_impute_copy(impute_num_adata):
+    adata_imputed = simple_impute(impute_num_adata, strategy="median", copy=True)

-    assert id(adata) != id(adata_imputed)
+    assert id(impute_num_adata) != id(adata_imputed)
     assert not np.isnan(adata_imputed.X).any()


-def test_median_impute_throws_error_non_numerical():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-
+def test_median_impute_throws_error_non_numerical(impute_adata):
     with pytest.raises(ValueError):
-        simple_impute(adata, strategy="median")
+        simple_impute(impute_adata, strategy="median")


-def test_median_impute_subset():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = simple_impute(adata, var_names=["intcol", "indexcol"], strategy="median", copy=True)
+def test_median_impute_subset(impute_adata):
+    adata_imputed = simple_impute(impute_adata, var_names=["intcol", "indexcol"], strategy="median", copy=True)

     assert not np.all([item != item for item in adata_imputed.X[::, 1:2]])
     assert np.any([item != item for item in adata_imputed.X[::, 3:4]])


-def test_most_frequent_impute_no_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    simple_impute(adata, strategy="most_frequent")
+def test_most_frequent_impute_no_copy(impute_adata):
+    simple_impute(impute_adata, strategy="most_frequent")

-    assert not (np.all([item != item for item in adata.X]))
+    assert not (np.all([item != item for item in impute_adata.X]))


-def test_most_frequent_impute_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = simple_impute(adata, strategy="most_frequent", copy=True)
+def test_most_frequent_impute_copy(impute_adata):
+    adata_imputed = simple_impute(impute_adata, strategy="most_frequent", copy=True)

-    assert id(adata) != id(adata_imputed)
+    assert id(impute_adata) != id(adata_imputed)
     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_most_frequent_impute_subset():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = simple_impute(adata, var_names=["intcol", "strcol"], strategy="most_frequent", copy=True)
+def test_most_frequent_impute_subset(impute_adata):
+    adata_imputed = simple_impute(impute_adata, var_names=["intcol", "strcol"], strategy="most_frequent", copy=True)

     assert not (np.all([item != item for item in adata_imputed.X[::, 1:3]]))


-def test_knn_impute_no_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    knn_impute(adata)
+def test_knn_impute_no_copy(impute_num_adata):
+    knn_impute(impute_num_adata)

-    assert not (np.all([item != item for item in adata.X]))
+    assert not (np.all([item != item for item in impute_num_adata.X]))


-def test_knn_impute_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = knn_impute(adata, n_neighbours=3, copy=True)
+def test_knn_impute_copy(impute_num_adata):
+    adata_imputed = knn_impute(impute_num_adata, n_neighbours=3, copy=True)

-    assert id(adata) != id(adata_imputed)
+    assert id(impute_num_adata) != id(adata_imputed)
     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_knn_impute_non_numerical_data():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = knn_impute(adata, n_neighbours=3, copy=True)
+def 
test_knn_impute_non_numerical_data(impute_adata): + adata_imputed = knn_impute(impute_adata, n_neighbours=3, copy=True) assert not (np.all([item != item for item in adata_imputed.X])) -def test_knn_impute_numerical_data(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv") - adata_imputed = knn_impute(adata, copy=True) +def test_knn_impute_numerical_data(impute_num_adata): + adata_imputed = knn_impute(impute_num_adata, copy=True) assert not (np.all([item != item for item in adata_imputed.X])) -def test_knn_impute_list_str(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv") - adata_imputed = knn_impute(adata, var_names=["intcol", "strcol", "boolcol"], copy=True) +def test_knn_impute_list_str(impute_adata): + adata_imputed = knn_impute(impute_adata, var_names=["intcol", "strcol", "boolcol"], copy=True) assert not (np.all([item != item for item in adata_imputed.X])) -def test_missforest_impute_non_numerical_data(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv") - adata_imputed = miss_forest_impute(adata, copy=True) +def test_missforest_impute_non_numerical_data(impute_adata): + adata_imputed = miss_forest_impute(impute_adata, copy=True) assert not (np.all([item != item for item in adata_imputed.X])) -def test_missforest_impute_numerical_data(): +def test_missforest_impute_numerical_data(impute_num_adata): warnings.filterwarnings("ignore", category=ConvergenceWarning) - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv") - adata_imputed = miss_forest_impute(adata, copy=True) + adata_imputed = miss_forest_impute(impute_num_adata, copy=True) assert not (np.all([item != item for item in adata_imputed.X])) -def test_missforest_impute_subset(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv") +def test_missforest_impute_subset(impute_num_adata): adata_imputed = miss_forest_impute( - adata, var_names={"non_numerical": ["intcol"], "numerical": ["strcol"]}, copy=True + impute_num_adata, var_names={"non_numerical": ["intcol"], "numerical": ["strcol"]}, copy=True ) assert not (np.all([item != item for item in adata_imputed.X])) -def test_missforest_impute_list_str(): +def test_missforest_impute_list_str(impute_num_adata): warnings.filterwarnings("ignore", category=ConvergenceWarning) - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv") - adata_imputed = miss_forest_impute(adata, var_names=["col1", "col2", "col3"], copy=True) + adata_imputed = miss_forest_impute(impute_num_adata, var_names=["col1", "col2", "col3"], copy=True) assert not (np.all([item != item for item in adata_imputed.X])) -def test_missforest_impute_dict(): +def test_missforest_impute_dict(impute_adata): warnings.filterwarnings("ignore", category=ConvergenceWarning) - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv") adata_imputed = miss_forest_impute( - adata, var_names={"numerical": ["intcol", "datetime"], "non_numerical": ["strcol", "boolcol"]}, copy=True + impute_adata, var_names={"numerical": ["intcol", "datetime"], "non_numerical": ["strcol", "boolcol"]}, copy=True ) assert not (np.all([item != item for item in adata_imputed.X])) -def test_soft_impute_no_copy(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv") - adata_imputed = soft_impute(adata) +def test_soft_impute_no_copy(impute_num_adata): + adata_imputed = soft_impute(impute_num_adata) - assert id(adata) == id(adata_imputed) + assert id(impute_num_adata) == id(adata_imputed) -def test_soft_impute_copy(): - adata = 
read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = soft_impute(adata, copy=True)
+def test_soft_impute_copy(impute_num_adata):
+    adata_imputed = soft_impute(impute_num_adata, copy=True)

-    assert id(adata) != id(adata_imputed)
+    assert id(impute_num_adata) != id(adata_imputed)


-def test_soft_impute_non_numerical_data():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = soft_impute(adata)
+def test_soft_impute_non_numerical_data(impute_adata):
+    adata_imputed = soft_impute(impute_adata)

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_soft_impute_numerical_data():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = soft_impute(adata)
+def test_soft_impute_numerical_data(impute_num_adata):
+    adata_imputed = soft_impute(impute_num_adata)

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_soft_impute_list_str():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = soft_impute(adata, var_names=["intcol", "strcol", "boolcol"])
+def test_soft_impute_list_str(impute_adata):
+    adata_imputed = soft_impute(impute_adata, var_names=["intcol", "strcol", "boolcol"])

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_IterativeSVD_impute_no_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = iterative_svd_impute(adata, rank=2)
+def test_IterativeSVD_impute_no_copy(impute_num_adata):
+    adata_imputed = iterative_svd_impute(impute_num_adata, rank=2)

-    assert id(adata) == id(adata_imputed)
+    assert id(impute_num_adata) == id(adata_imputed)


-def test_IterativeSVD_impute_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = iterative_svd_impute(adata, rank=2, copy=True)
+def test_IterativeSVD_impute_copy(impute_num_adata):
+    adata_imputed = iterative_svd_impute(impute_num_adata, rank=2, copy=True)

-    assert id(adata) != id(adata_imputed)
+    assert id(impute_num_adata) != id(adata_imputed)


-def test_IterativeSVD_impute_non_numerical_data():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = iterative_svd_impute(adata, rank=3)
+def test_IterativeSVD_impute_non_numerical_data(impute_adata):
+    adata_imputed = iterative_svd_impute(impute_adata, rank=3)

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_IterativeSVD_impute_numerical_data():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = iterative_svd_impute(adata, rank=2)
+def test_IterativeSVD_impute_numerical_data(impute_num_adata):
+    adata_imputed = iterative_svd_impute(impute_num_adata, rank=2)

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_IterativeSVD_impute_list_str():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = iterative_svd_impute(adata, var_names=["intcol", "strcol", "boolcol"], rank=2)
+def test_IterativeSVD_impute_list_str(impute_adata):
+    adata_imputed = iterative_svd_impute(impute_adata, var_names=["intcol", "strcol", "boolcol"], rank=2)

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_matrix_factorization_impute_no_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = matrix_factorization_impute(adata)
+def test_matrix_factorization_impute_no_copy(impute_num_adata):
+    adata_imputed = matrix_factorization_impute(impute_num_adata)

-    assert id(adata) == id(adata_imputed)
+    assert id(impute_num_adata) == id(adata_imputed)


-def test_matrix_factorization_impute_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = matrix_factorization_impute(adata, copy=True)
+def test_matrix_factorization_impute_copy(impute_num_adata):
+    adata_imputed = matrix_factorization_impute(impute_num_adata, copy=True)

-    assert id(adata) != id(adata_imputed)
+    assert id(impute_num_adata) != id(adata_imputed)


-def test_matrix_factorization_impute_non_numerical_data():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = matrix_factorization_impute(adata)
+def test_matrix_factorization_impute_non_numerical_data(impute_adata):
+    adata_imputed = matrix_factorization_impute(impute_adata)

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_matrix_factorization_impute_numerical_data():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = matrix_factorization_impute(adata)
+def test_matrix_factorization_impute_numerical_data(impute_num_adata):
+    adata_imputed = matrix_factorization_impute(impute_num_adata)

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_matrix_factorization_impute_list_str():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = matrix_factorization_impute(adata, var_names=["intcol", "strcol", "boolcol"])
+def test_matrix_factorization_impute_list_str(impute_adata):
+    adata_imputed = matrix_factorization_impute(impute_adata, var_names=["intcol", "strcol", "boolcol"])

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_nuclear_norm_minimization_impute_no_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = nuclear_norm_minimization_impute(adata)
+def test_nuclear_norm_minimization_impute_no_copy(impute_num_adata):
+    adata_imputed = nuclear_norm_minimization_impute(impute_num_adata)

-    assert id(adata) == id(adata_imputed)
+    assert id(impute_num_adata) == id(adata_imputed)


-def test_nuclear_norm_minimization_impute_copy():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = nuclear_norm_minimization_impute(adata, copy=True)
+def test_nuclear_norm_minimization_impute_copy(impute_num_adata):
+    adata_imputed = nuclear_norm_minimization_impute(impute_num_adata, copy=True)

-    assert id(adata) != id(adata_imputed)
+    assert id(impute_num_adata) != id(adata_imputed)


-def test_nuclear_norm_minimization_impute_non_numerical_data():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = nuclear_norm_minimization_impute(adata)
+def test_nuclear_norm_minimization_impute_non_numerical_data(impute_adata):
+    adata_imputed = nuclear_norm_minimization_impute(impute_adata)

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_nuclear_norm_minimization_impute_numerical_data():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv")
-    adata_imputed = nuclear_norm_minimization_impute(adata)
+def test_nuclear_norm_minimization_impute_numerical_data(impute_num_adata):
+    adata_imputed = nuclear_norm_minimization_impute(impute_num_adata)

     assert not (np.all([item != item for item in adata_imputed.X]))


-def test_nuclear_norm_minimization_impute_list_str():
-    adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv")
-    adata_imputed = nuclear_norm_minimization_impute(adata, var_names=["intcol", "strcol", "boolcol"])
+def 
test_nuclear_norm_minimization_impute_list_str(impute_adata): + adata_imputed = nuclear_norm_minimization_impute(impute_adata, var_names=["intcol", "strcol", "boolcol"]) assert not (np.all([item != item for item in adata_imputed.X])) @pytest.mark.skipif(os.name == "posix", reason="miceforest Imputation not supported by MacOS.") -def test_miceforest_impute_no_copy(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_iris.csv") - adata_imputed = mice_forest_impute(adata) +def test_miceforest_impute_no_copy(impute_iris): + adata_imputed = mice_forest_impute(impute_iris) - assert id(adata) == id(adata_imputed) + assert id(impute_iris) == id(adata_imputed) @pytest.mark.skipif(os.name == "posix", reason="miceforest Imputation not supported by MacOS.") -def test_miceforest_impute_copy(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_iris.csv") - adata_imputed = mice_forest_impute(adata, copy=True) +def test_miceforest_impute_copy(impute_iris): + adata_imputed = mice_forest_impute(impute_iris, copy=True) - assert id(adata) != id(adata_imputed) + assert id(impute_iris) != id(adata_imputed) @pytest.mark.skipif(os.name == "posix", reason="miceforest Imputation not supported by MacOS.") -def test_miceforest_impute_non_numerical_data(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_titanic.csv") - adata_imputed = mice_forest_impute(adata) +def test_miceforest_impute_non_numerical_data(impute_titanic): + adata_imputed = mice_forest_impute(impute_titanic) assert not (np.all([item != item for item in adata_imputed.X])) @pytest.mark.skipif(os.name == "posix", reason="miceforest Imputation not supported by MacOS.") -def test_miceforest_impute_numerical_data(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_iris.csv") - adata_imputed = mice_forest_impute(adata) +def test_miceforest_impute_numerical_data(impute_iris): + adata_imputed = mice_forest_impute(impute_iris) assert not (np.all([item != item for item in adata_imputed.X])) @pytest.mark.skipif(os.name == "posix", reason="miceforest Imputation not supported by MacOS.") -def test_miceforest_impute_list_str(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_titanic.csv") - adata_imputed = mice_forest_impute(adata, var_names=["Cabin", "Age"]) +def test_miceforest_impute_list_str(impute_titanic): + adata_imputed = mice_forest_impute(impute_titanic, var_names=["Cabin", "Age"]) assert not (np.all([item != item for item in adata_imputed.X])) -def test_explicit_impute_all(): +def test_explicit_impute_all(impute_num_adata): warnings.filterwarnings("ignore", category=FutureWarning) - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv") - adata_imputed = explicit_impute(adata, replacement=1011, copy=True) + adata_imputed = explicit_impute(impute_num_adata, replacement=1011, copy=True) assert (adata_imputed.X == 1011).sum() == 3 -def test_explicit_impute_subset(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv") - adata_imputed = explicit_impute(adata, replacement={"strcol": "REPLACED", "intcol": 1011}, copy=True) +def test_explicit_impute_subset(impute_adata): + adata_imputed = explicit_impute(impute_adata, replacement={"strcol": "REPLACED", "intcol": 1011}, copy=True) assert (adata_imputed.X == 1011).sum() == 1 assert (adata_imputed.X == "REPLACED").sum() == 1 -def test_warning(): - adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv") - warning_results = _warn_imputation_threshold(adata, threshold=20, var_names=None) +def test_warning(impute_num_adata): + 
warning_results = _warn_imputation_threshold(impute_num_adata, threshold=20, var_names=None) assert warning_results == {"col1": 25, "col3": 50} diff --git a/tests/preprocessing/test_expand_measurements.py b/tests/preprocessing/test_summarize_measurements.py similarity index 100% rename from tests/preprocessing/test_expand_measurements.py rename to tests/preprocessing/test_summarize_measurements.py
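
Reviewer note: `filter_rank_features_groups` moves out of `ehrapy/tools/_scanpy_tl_api.py` into `ehrapy/tools/feature_ranking/_rank_features_groups.py` and is now re-exported from `ehrapy.tools`. Since the relocated docstring's example never calls the function itself, here is a minimal usage sketch of the intended flow. It assumes the `ep` alias and the `mimic_2` demo dataset used throughout the docstrings; the keyword values merely restate the function's own defaults, and passing `key` to the plotting call assumes the plotting wrapper forwards it the way scanpy's rank_genes_groups plot does.

    import ehrapy as ep

    adata = ep.dt.mimic_2(encoded=True)

    # Rank features per group first; the filter reads its input from
    # adata.uns["rank_features_groups"].
    ep.tl.rank_features_groups(adata, "service_unit")

    # Filter the ranking: feature names that fail the fraction/fold-change
    # thresholds are set to NaN while the structure of the ranking is kept.
    # The keyword values shown are the function's defaults.
    ep.tl.filter_rank_features_groups(
        adata,
        key="rank_features_groups",
        key_added="rank_features_groups_filtered",
        min_in_group_fraction=0.25,
        min_fold_change=1,
        max_out_group_fraction=0.5,
    )

    # Plot the filtered ranking stored under key_added (assumption: the
    # plotting wrapper accepts `key` as in scanpy).
    ep.pl.rank_features_groups(adata, key="rank_features_groups_filtered")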