From 425b054d8eb01b2e1e809a95bb2a5f41ef6c3370 Mon Sep 17 00:00:00 2001 From: Daniel McAuley Date: Mon, 16 Feb 2026 17:51:31 -0800 Subject: [PATCH 1/4] Stabilize CLI docs and parsing --- .github/workflows/pytest.yml | 6 +-- README.md | 14 ++--- requirements.txt | 1 - setup.cfg | 7 ++- spout_mouse/__init__.py | 18 ------- spout_mouse/analysis.py | 2 + spout_mouse/cli.py | 34 ++++++++++++ spout_mouse/data_loading.py | 36 ++++++++++--- spout_mouse/data_processing.py | 32 ++++++++++-- tests/test_cli.py | 19 +++++++ tests/test_data_loading.py | 17 ++++++ tests/test_data_processing.py | 50 +++++++++++++++++- tests/test_fiber_photometry.py | 96 +++++++++++++++++++++++++++++++++- tests/test_package_exports.py | 10 ++++ 14 files changed, 295 insertions(+), 47 deletions(-) create mode 100644 tests/test_cli.py create mode 100644 tests/test_package_exports.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 191bfa0..4ed342a 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -16,7 +16,7 @@ jobs: - name: Check out repository code uses: actions/checkout@v4 - - name: Set up Python 3.10 + - name: Set up Python 3.11 uses: actions/setup-python@v5 with: python-version: '3.11' @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt + pip install -e ".[dev]" - name: Run tests with pytest - run: pytest + run: python -m pytest -q diff --git a/README.md b/README.md index da8a0d6..71c68f7 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,12 @@ cd spout_mouse pip install . ``` +## CLI + +```bash +fp-analysis --help +``` + ## Usage ### Processing Lick Data @@ -55,8 +61,8 @@ with open('path_to_credentials.json') as f: client = authorize_google_sheets(credentials_json) # Fetch experiment data -experiment_records = get_experiment_data(client) -experiment_df = pd.DataFrame(experiment_records) +google_sheet_url = "https://docs.google.com/spreadsheets/d/your-sheet-id" +experiment_df = get_experiment_data(client, google_sheet_url) # Load and process spout metadata spout_names = load_experiment_metadata( @@ -97,7 +103,6 @@ from spout_mouse import ( calculate_mean_zscore, prepare_long_format, calculate_mean_sem_zscores, - plot_zscore_traces, # Assuming this function exists for plotting ) # Define the directory pattern for the data blocks @@ -128,9 +133,6 @@ mean_zscore_long = prepare_long_format(mean_zscore_df, across_days=False) # Calculate mean and SEM of z-scores for plotting mean_sem_zscores = calculate_mean_sem_zscores(mean_zscore_df, across_days=False) -# Generate plots (assuming you have a plotting function) -plot_zscore_traces(mean_sem_zscores) - # Save the processed data mean_sem_zscores.to_csv('mean_sem_zscores.csv', index=False) ``` diff --git a/requirements.txt b/requirements.txt index 59b3b1a..3d9dce9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,3 @@ gspread>=3.6 google-auth>=1.11 pingouin>=0.3.8 tdt>=0.6.6 -pytest-mock>=3.0.0 diff --git a/setup.cfg b/setup.cfg index dee5036..855ad53 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,15 +18,15 @@ classifiers = Topic :: Scientific/Engineering :: Bio-Informatics License :: OSI Approved :: MIT License Programming Language :: Python :: 3 - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 Operating System :: OS Independent keywords = neuroscience, fiber photometry, data analysis, biology [options] packages = find: -python_requires = >=3.7 +python_requires = >=3.9 install_requires = pandas>=1.0 numpy>=1.18 @@ -41,7 +41,6 @@ install_requires = google-auth>=1.11 pingouin>=0.3.8 tdt>=0.6.6 - pytest-mock>=3.0.0 include_package_data = True [options.extras_require] diff --git a/spout_mouse/__init__.py b/spout_mouse/__init__.py index f022ef9..8f08710 100644 --- a/spout_mouse/__init__.py +++ b/spout_mouse/__init__.py @@ -91,19 +91,7 @@ 'BASELINE_END', 'SEC_TO_DROP_START', 'SEC_TO_DROP_END', - 'GOOGLE_SHEET_URL', 'MOUSE_GROUPS', - 'downsample_stream', - 'double_exponential', - 'get_bounds', - 'estimate_amplitude', - 'estimate_time_constant', - 'get_initial_params', - 'detrend_signal', - 'build_traces_df', - 'build_spout_df', - 'calculate_zscores', - 'add_auc', 'prepare_fp_dataframe', 'clean_fp_trials', 'truncate_zscore_arrays', @@ -111,10 +99,4 @@ 'calculate_mean_zscore', 'prepare_long_format', 'calculate_mean_sem_zscores', - 'nape_cart_processing', - 'is_first_mouse', - 'is_second_mouse', - 'process_mouse', - 'process_block_path', - 'process_all_blocks', ] diff --git a/spout_mouse/analysis.py b/spout_mouse/analysis.py index fa07921..84100ed 100644 --- a/spout_mouse/analysis.py +++ b/spout_mouse/analysis.py @@ -96,6 +96,8 @@ def aggregate_data_and_calculate_sem(lick_data_spout: pd.DataFrame, combine_days groups += ["day"] def sem_func(arr): + if len(arr) <= 1: + return np.nan return stats.sem(arr, axis=None, ddof=0) lick_data_grouped = lick_data_spout.groupby(groups).agg( diff --git a/spout_mouse/cli.py b/spout_mouse/cli.py index e69de29..1dd68d6 100644 --- a/spout_mouse/cli.py +++ b/spout_mouse/cli.py @@ -0,0 +1,34 @@ +"""Command-line interface for spout_mouse.""" + +import argparse +import sys +from typing import Optional, Sequence + + +def build_parser() -> argparse.ArgumentParser: + """Build the top-level CLI parser.""" + return argparse.ArgumentParser( + prog="fp-analysis", + description=( + "spout_mouse command-line entrypoint. " + "Use this package from Python for full analysis workflows." + ), + ) + + +def main(argv: Optional[Sequence[str]] = None) -> int: + """ + Run the CLI. + + With no arguments, prints help and exits successfully. + """ + parser = build_parser() + if argv is None: + argv = sys.argv[1:] + + if not argv: + parser.print_help() + return 0 + + parser.parse_args(argv) + return 0 diff --git a/spout_mouse/data_loading.py b/spout_mouse/data_loading.py index c09bf49..1266d29 100644 --- a/spout_mouse/data_loading.py +++ b/spout_mouse/data_loading.py @@ -1,8 +1,7 @@ import numpy as np import pandas as pd import os -from typing import List -from tdt import read_block +import re from .config import ( DOWNSAMPLE_RATE, SEC_TO_DROP_START, @@ -60,14 +59,35 @@ def build_spout_df(timestamps: np.ndarray, block_path: str, mouse_id: str) -> pd 'mouse_id': mouse_id }) - # Extract cohort and day from block_path - parts = block_path.split(os.sep) - cohort_part = parts[3] - cohort = int(cohort_part.split()[1]) - day_part = parts[2] - day = int(day_part.split()[1]) + day, cohort = extract_day_and_cohort_from_path(block_path) spout_ext_df['cohort'] = cohort spout_ext_df['day'] = day return spout_ext_df + + +def extract_day_and_cohort_from_path(path: str) -> tuple[int, int]: + """ + Extract day and cohort values from a path. + + The function searches path segments for values formatted as + "day " and "cohort " at any directory depth. + """ + pattern = re.compile(r"^(day|cohort)\s+(\d+)$", re.IGNORECASE) + values: dict[str, int] = {} + + for part in os.path.normpath(path).split(os.sep): + match = pattern.match(part.strip()) + if not match: + continue + key, value = match.group(1).lower(), int(match.group(2)) + if key not in values: + values[key] = value + + missing = [key for key in ("day", "cohort") if key not in values] + if missing: + missing_str = ", ".join(missing) + raise ValueError(f"Could not parse {missing_str} from path: {path}") + + return values["day"], values["cohort"] diff --git a/spout_mouse/data_processing.py b/spout_mouse/data_processing.py index d597614..d8b93d3 100644 --- a/spout_mouse/data_processing.py +++ b/spout_mouse/data_processing.py @@ -14,6 +14,7 @@ LICK_DATA_COLS, MOUSE_GROUPS ) +from .data_loading import extract_day_and_cohort_from_path def extract_zip_files(zip_file_paths: List[str], extract_to: str) -> None: @@ -106,9 +107,10 @@ def process_lick_data(data_directory: str, mouse_ids_to_remove: List[str] = None .ffill() lick_data = lick_data.apply(pd.to_numeric) - lick_data["mouse_id"] = os.path.basename(file_path).split("_")[3].split(".")[0] - lick_data["cohort"] = int(file_path.split("/")[3].split()[1]) - lick_data["day"] = int(file_path.split("/")[2].split()[1]) + lick_data["mouse_id"] = _extract_mouse_id_from_filename(file_path) + day, cohort = extract_day_and_cohort_from_path(file_path) + lick_data["cohort"] = cohort + lick_data["day"] = day lick_data = lick_data.loc[lick_data["time_ms"] > 0] lick_data = lick_data[lick_data["event_tag"].isin(LICK_CODES + [SPOUT_EXT_CODE])] @@ -130,11 +132,33 @@ def process_lick_data(data_directory: str, mouse_ids_to_remove: List[str] = None lick_data_all = pd.concat(lick_data_list, ignore_index=True) if mouse_ids_to_remove: - lick_data_all = lick_data_all[lick_data_all['mouse_id'].isin(mouse_ids_to_remove)] + lick_data_all = lick_data_all[~lick_data_all['mouse_id'].isin(mouse_ids_to_remove)] return lick_data_all +def _extract_mouse_id_from_filename(file_path: str) -> str: + """ + Extract mouse_id from the expected CSV filename format. + + Expected filename contains at least four underscore-separated fields where + the fourth field is the mouse ID, optionally followed by extension. + """ + filename = os.path.basename(file_path) + parts = filename.split("_") + if len(parts) < 4: + raise ValueError( + "Could not parse mouse_id from filename " + f"'{filename}'. Expected at least 4 underscore-separated parts." + ) + + mouse_id = parts[3].split(".")[0] + if not mouse_id: + raise ValueError(f"Could not parse mouse_id from filename '{filename}'.") + + return mouse_id + + def compute_spout_order(lick_data: pd.DataFrame) -> pd.DataFrame: """ Computes the 'spout_order' DataFrame by grouping 'lick_data' by 'cohort', 'day', and 'trial_num', diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..e90dcf4 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,19 @@ +import pytest + +from spout_mouse.cli import main + + +def test_main_no_args_prints_help(capsys): + result = main([]) + captured = capsys.readouterr() + + assert result == 0 + assert "usage:" in captured.out + assert "fp-analysis" in captured.out + + +def test_main_invalid_args_raises_system_exit(): + with pytest.raises(SystemExit) as exc_info: + main(["--not-a-real-flag"]) + + assert exc_info.value.code == 2 diff --git a/tests/test_data_loading.py b/tests/test_data_loading.py index 3dc1145..e2ccd73 100644 --- a/tests/test_data_loading.py +++ b/tests/test_data_loading.py @@ -27,3 +27,20 @@ def test_build_spout_df(): assert all(spout_df['cohort'] == 1) assert all(spout_df['day'] == 2) + +def test_build_spout_df_with_absolute_path(): + timestamps = np.array([10.0, 20.0]) + block_path = "/Users/daniel/data/experiments/cohort 4/raw/day 7/tanks/block-0000-0001" + + spout_df = build_spout_df(timestamps, block_path, "0001") + + assert all(spout_df["cohort"] == 4) + assert all(spout_df["day"] == 7) + + +def test_build_spout_df_missing_day_or_cohort_raises_value_error(): + timestamps = np.array([10.0]) + bad_path = os.path.join("path", "to", "data", "no_day_or_cohort", "block-0000-0001") + + with pytest.raises(ValueError, match="Could not parse"): + build_spout_df(timestamps, bad_path, "0001") diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index d748300..7457568 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -2,7 +2,6 @@ import pandas as pd import os from spout_mouse import data_processing -from unittest import mock import tempfile import shutil @@ -122,3 +121,52 @@ def test_merge_spout_info(): 'group': [None] # Assuming 'mouse1' not in MOUSE_GROUPS }) pd.testing.assert_frame_equal(result, expected) + + +def _write_minimal_lick_csv(file_path: str): + with open(file_path, "w", encoding="utf-8") as handle: + handle.write("event_pos_time\n") + handle.write("127 1\n") + handle.write("13 100\n") + handle.write("331 200\n") + handle.write("331 300\n") + + +def test_process_lick_data_excludes_mouse_ids(): + with tempfile.TemporaryDirectory() as temp_dir: + cohort_dir = os.path.join(temp_dir, "day 1", "cohort 2") + os.makedirs(cohort_dir) + + _write_minimal_lick_csv(os.path.join(cohort_dir, "session_block_trial_0001.csv")) + _write_minimal_lick_csv(os.path.join(cohort_dir, "session_block_trial_0002.csv")) + + result = data_processing.process_lick_data( + temp_dir, + mouse_ids_to_remove=["0002"], + ) + + assert not result.empty + assert (result["mouse_id"] == "0002").sum() == 0 + assert (result["mouse_id"] == "0001").sum() > 0 + + +def test_process_lick_data_invalid_filename_raises_value_error(): + with tempfile.TemporaryDirectory() as temp_dir: + cohort_dir = os.path.join(temp_dir, "day 1", "cohort 2") + os.makedirs(cohort_dir) + + _write_minimal_lick_csv(os.path.join(cohort_dir, "bad.csv")) + + with pytest.raises(ValueError, match="Could not parse mouse_id"): + data_processing.process_lick_data(temp_dir) + + +def test_process_lick_data_invalid_path_raises_value_error(): + with tempfile.TemporaryDirectory() as temp_dir: + invalid_dir = os.path.join(temp_dir, "misc") + os.makedirs(invalid_dir) + + _write_minimal_lick_csv(os.path.join(invalid_dir, "session_block_trial_0001.csv")) + + with pytest.raises(ValueError, match="Could not parse"): + data_processing.process_lick_data(temp_dir) diff --git a/tests/test_fiber_photometry.py b/tests/test_fiber_photometry.py index 7beda55..bb035d1 100644 --- a/tests/test_fiber_photometry.py +++ b/tests/test_fiber_photometry.py @@ -1,7 +1,8 @@ import pytest import numpy as np import pandas as pd -from unittest.mock import patch, MagicMock +from types import SimpleNamespace +from unittest.mock import patch, call from spout_mouse.fiber_photometry import ( nape_cart_processing, is_first_mouse, @@ -11,7 +12,6 @@ process_all_blocks ) from spout_mouse.config import DOWNSAMPLE_RATE -import os import warnings from scipy.optimize import OptimizeWarning @@ -78,3 +78,95 @@ def test_process_mouse( mock_add_auc.assert_called_once_with( mock_calculate_zscores.return_value, sample_rate, DOWNSAMPLE_RATE ) + + +def _mock_tdt_data(): + return SimpleNamespace( + streams=SimpleNamespace( + _470A=SimpleNamespace(fs=470.0, data=np.array([0.47, 0.48])), + _465A=SimpleNamespace(fs=465.0, data=np.array([0.11, 0.12])), + _465C=SimpleNamespace(fs=466.0, data=np.array([0.21, 0.22])), + ), + epocs=SimpleNamespace( + PtC0=SimpleNamespace(onset=np.array([1.0, 2.0])), + PtC2=SimpleNamespace(onset=np.array([3.0, 4.0])), + ), + ) + + +@patch("spout_mouse.fiber_photometry.process_mouse") +@patch("spout_mouse.fiber_photometry.downsample_stream") +@patch("spout_mouse.fiber_photometry.read_block") +def test_process_block_path_nape_cart_branch(mock_read_block, mock_downsample_stream, mock_process_mouse): + block_path = "path/to/day 1/cohort 1/0001-0002-0003" + mock_read_block.return_value = _mock_tdt_data() + mock_downsample_stream.return_value = np.array([0.9, 0.8]) + expected_df = pd.DataFrame({"mouse_id": ["0001"], "trial_num": [1]}) + mock_process_mouse.return_value = expected_df + + result = process_block_path(block_path) + + pd.testing.assert_frame_equal(result, expected_df) + mock_downsample_stream.assert_called_once_with(mock_read_block.return_value.streams._470A.data) + assert mock_process_mouse.call_count == 1 + args = mock_process_mouse.call_args.args + np.testing.assert_array_equal(args[0], np.array([0.9, 0.8])) + assert args[1] == mock_read_block.return_value.streams._470A.fs + np.testing.assert_array_equal(args[2], mock_read_block.return_value.epocs.PtC0.onset) + assert args[3] == block_path + assert args[4] == "0001" + + +@patch("spout_mouse.fiber_photometry.process_mouse") +@patch("spout_mouse.fiber_photometry.downsample_stream") +@patch("spout_mouse.fiber_photometry.read_block") +def test_process_block_path_dual_mouse_branch(mock_read_block, mock_downsample_stream, mock_process_mouse): + block_path = "path/to/day 1/cohort 1/0001-0002-0003-0004" + mock_read_block.return_value = _mock_tdt_data() + mock_downsample_stream.side_effect = [np.array([1.0]), np.array([2.0])] + mock_process_mouse.side_effect = [ + pd.DataFrame({"mouse_id": ["0001"], "trial_num": [1]}), + pd.DataFrame({"mouse_id": ["0002"], "trial_num": [1]}), + ] + + result = process_block_path(block_path) + + assert len(result) == 2 + first_args = mock_process_mouse.call_args_list[0].args + second_args = mock_process_mouse.call_args_list[1].args + + np.testing.assert_array_equal(first_args[0], np.array([1.0])) + assert first_args[1] == mock_read_block.return_value.streams._465A.fs + np.testing.assert_array_equal(first_args[2], mock_read_block.return_value.epocs.PtC0.onset) + assert first_args[3] == block_path + assert first_args[4] == "0001" + + np.testing.assert_array_equal(second_args[0], np.array([2.0])) + assert second_args[1] == mock_read_block.return_value.streams._465C.fs + np.testing.assert_array_equal(second_args[2], mock_read_block.return_value.epocs.PtC2.onset) + assert second_args[3] == block_path + assert second_args[4] == "0002" + + +@patch("spout_mouse.fiber_photometry.process_block_path") +@patch("spout_mouse.fiber_photometry.glob.glob") +def test_process_all_blocks_concatenates_all_results(mock_glob, mock_process_block_path): + mock_glob.return_value = [ + "path/to/day 1/cohort 1/block-0000-0001", + "path/to/day 2/cohort 1/block-0002-0003", + ] + mock_process_block_path.side_effect = [ + pd.DataFrame({"mouse_id": ["0000"], "trial_num": [1]}), + pd.DataFrame({"mouse_id": ["0002"], "trial_num": [1]}), + ] + + result = process_all_blocks("path/to/*") + + assert len(result) == 2 + assert result["mouse_id"].tolist() == ["0000", "0002"] + mock_process_block_path.assert_has_calls( + [ + call("path/to/day 1/cohort 1/block-0000-0001"), + call("path/to/day 2/cohort 1/block-0002-0003"), + ] + ) diff --git a/tests/test_package_exports.py b/tests/test_package_exports.py new file mode 100644 index 0000000..a8d1f8f --- /dev/null +++ b/tests/test_package_exports.py @@ -0,0 +1,10 @@ +import spout_mouse + + +def test_all_exports_exist(): + missing = [name for name in spout_mouse.__all__ if not hasattr(spout_mouse, name)] + assert missing == [] + + +def test_all_exports_are_unique(): + assert len(spout_mouse.__all__) == len(set(spout_mouse.__all__)) From 8b4d9ff76d1fd625c2c6622c6f7cdc4d807117b8 Mon Sep 17 00:00:00 2001 From: Daniel McAuley Date: Mon, 16 Feb 2026 20:05:13 -0800 Subject: [PATCH 2/4] Add CLI and parsing fixes --- tests/test_analysis.py | 5 ++--- tests/test_data_processing.py | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index cdc8016..520b1d8 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -2,7 +2,7 @@ import pandas as pd import numpy as np from spout_mouse import analysis -from spout_mouse.config import DOWNSAMPLE_RATE, MOUSE_GROUPS +from spout_mouse.config import DOWNSAMPLE_RATE from unittest.mock import patch @@ -125,14 +125,13 @@ def test_add_auc(self): self.assertEqual(len(result_df), 2) self.assertIsInstance(result_df['auc'].iloc[0], float) - @patch('spout_mouse.config.MOUSE_GROUPS', {'1274': 'sgRosa26'}) def test_prepare_fp_dataframe(self): excluded_mice = ["0037", "9694", "1228", "0036", "0039", "9692", "0061"] prepared_df = analysis.prepare_fp_dataframe(self.fp_df, excluded_mice) # Check that mouse_id is string self.assertTrue(prepared_df['mouse_id'].dtype == object) # Check that 'group' is mapped correctly - expected_groups = prepared_df['mouse_id'].map(MOUSE_GROUPS) + expected_groups = prepared_df['mouse_id'].map(analysis.MOUSE_GROUPS) pd.testing.assert_series_equal(prepared_df['group'], expected_groups, check_names=False) # Check that excluded mice are removed self.assertFalse(prepared_df['mouse_id'].isin(excluded_mice).any()) diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index 7457568..81148e8 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -1,5 +1,6 @@ import pytest import pandas as pd +import numpy as np import os from spout_mouse import data_processing import tempfile @@ -118,7 +119,7 @@ def test_merge_spout_info(): 'lick_count': [5], 'lick_count_hz': [25.0], 'spout_name': ['water'], - 'group': [None] # Assuming 'mouse1' not in MOUSE_GROUPS + 'group': pd.Series([np.nan], dtype=object) }) pd.testing.assert_frame_equal(result, expected) From 163c24cb8eef5eefc11cd2f65ac2d966db8b3fb5 Mon Sep 17 00:00:00 2001 From: dmca Date: Mon, 16 Feb 2026 20:08:25 -0800 Subject: [PATCH 3/4] Fix pytest failures around dtype handling and mocker fixture --- spout_mouse/analysis.py | 2 +- spout_mouse/data_processing.py | 3 ++- tests/conftest.py | 26 ++++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 tests/conftest.py diff --git a/spout_mouse/analysis.py b/spout_mouse/analysis.py index 84100ed..ac450e5 100644 --- a/spout_mouse/analysis.py +++ b/spout_mouse/analysis.py @@ -189,7 +189,7 @@ def prepare_fp_dataframe(fp_df: pd.DataFrame, excluded_mice: list[str] = None) - pd.DataFrame: The prepared DataFrame. """ # Convert mouse_id from int to str - fp_df["mouse_id"] = fp_df["mouse_id"].astype(str) + fp_df["mouse_id"] = fp_df["mouse_id"].astype(str).astype(object) # Map mouse_id to group fp_df["group"] = fp_df["mouse_id"].map(MOUSE_GROUPS) diff --git a/spout_mouse/data_processing.py b/spout_mouse/data_processing.py index d8b93d3..b72654f 100644 --- a/spout_mouse/data_processing.py +++ b/spout_mouse/data_processing.py @@ -259,5 +259,6 @@ def merge_spout_info(lick_rate: pd.DataFrame, spout_names: pd.DataFrame) -> pd.D pd.DataFrame: Merged DataFrame with spout and group information. """ merged_data = lick_rate.merge(spout_names, on=['cohort', 'day', 'spout_id'], how='left') - merged_data['group'] = merged_data['mouse_id'].map(MOUSE_GROUPS) + merged_data['group'] = merged_data['mouse_id'].map(MOUSE_GROUPS).astype(object) + merged_data['group'] = merged_data['group'].replace({np.nan: None}) return merged_data diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..256cc0f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,26 @@ +from unittest.mock import patch + +import pytest + + +class _Mocker: + def __init__(self): + self._patchers = [] + + def patch(self, target, *args, **kwargs): + patcher = patch(target, *args, **kwargs) + self._patchers.append(patcher) + return patcher.start() + + def stopall(self): + while self._patchers: + self._patchers.pop().stop() + + +@pytest.fixture +def mocker(): + mocker_instance = _Mocker() + try: + yield mocker_instance + finally: + mocker_instance.stopall() From c608eeb6c7f7f629a8811a9e68aaf3eb0bc6b190 Mon Sep 17 00:00:00 2001 From: dmca Date: Mon, 16 Feb 2026 20:13:17 -0800 Subject: [PATCH 4/4] Fix merge_spout_info missing group value handling --- spout_mouse/data_processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spout_mouse/data_processing.py b/spout_mouse/data_processing.py index b72654f..49aa116 100644 --- a/spout_mouse/data_processing.py +++ b/spout_mouse/data_processing.py @@ -260,5 +260,4 @@ def merge_spout_info(lick_rate: pd.DataFrame, spout_names: pd.DataFrame) -> pd.D """ merged_data = lick_rate.merge(spout_names, on=['cohort', 'day', 'spout_id'], how='left') merged_data['group'] = merged_data['mouse_id'].map(MOUSE_GROUPS).astype(object) - merged_data['group'] = merged_data['group'].replace({np.nan: None}) return merged_data