Add cohorts snapshot tests with syrupy (#379)
* Add cohorts snapshot tests with syrupy

* Fix.

* fix again

* Rework CI

* [revery]

* improve

* fix mypy?

* Revert "[revery]"

This reverts commit 7664e5e.

* Try again

* fix mypy
dcherian authored Aug 2, 2024
1 parent 4d03d70 commit cb3fc1f
Showing 13 changed files with 22,362 additions and 50 deletions.
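The headline change adds snapshot tests for cohort detection using syrupy, which is why `syrupy` appears in every CI environment file below. As a rough sketch of the testing pattern (the test name and inputs here are hypothetical, not taken from this commit): syrupy injects a `snapshot` pytest fixture, a first run with `pytest --snapshot-update` records the value under `tests/__snapshots__/`, and later runs fail if the computed value drifts.

import numpy as np

def test_cohorts_snapshot(snapshot):
    # Hypothetical example: snapshot plain Python objects so the stored
    # representation stays readable and stable across NumPy versions.
    labels = np.repeat(np.arange(4), 3).tolist()
    assert labels == snapshot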
3 changes: 2 additions & 1 deletion .github/workflows/ci-additional.yaml
@@ -128,7 +128,8 @@ jobs:
       - name: Run mypy
         run: |
-          python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report
+          mkdir .mypy_cache
+          python -m mypy --install-types --non-interactive --cache-dir=.mypy_cache/ --cobertura-xml-report mypy_report
       - name: Upload mypy coverage to Codecov
         uses: codecov/[email protected]
54 changes: 10 additions & 44 deletions .github/workflows/ci.yaml
@@ -16,7 +16,7 @@ concurrency:

 jobs:
   test:
-    name: Test (${{ matrix.python-version }}, ${{ matrix.os }})
+    name: Test (${{matrix.env}}, ${{ matrix.python-version }}, ${{ matrix.os }})
     runs-on: ${{ matrix.os }}
     defaults:
       run:
@@ -25,10 +25,18 @@ jobs:
       fail-fast: false
       matrix:
         os: ["ubuntu-latest"]
+        env: ["environment"]
         python-version: ["3.9", "3.12"]
         include:
           - os: "windows-latest"
+            env: "environment"
             python-version: "3.12"
+          - os: "ubuntu-latest"
+            env: "no-dask" # "no-xarray", "no-numba"
+            python-version: "3.12"
+          - os: "ubuntu-latest"
+            env: "minimal-requirements"
+            python-version: "3.9"
     steps:
       - uses: actions/checkout@v4
         with:
@@ -39,7 +47,7 @@ jobs:
       - name: Set up conda environment
         uses: mamba-org/setup-micromamba@v1
         with:
-          environment-file: ci/environment.yml
+          environment-file: ci/${{ matrix.env }}.yml
           environment-name: flox-tests
           init-shell: bash
           cache-environment: true
@@ -81,48 +89,6 @@ jobs:
           path: .hypothesis/
           key: cache-hypothesis-${{ runner.os }}-${{ matrix.python-version }}-${{ github.run_id }}

-  optional-deps:
-    name: ${{ matrix.env }}
-    runs-on: "ubuntu-latest"
-    defaults:
-      run:
-        shell: bash -l {0}
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: ["3.12"]
-        env: ["no-dask"] # "no-xarray", "no-numba"
-        include:
-          - env: "minimal-requirements"
-            python-version: "3.9"
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          fetch-depth: 0 # Fetch all history for all branches and tags.
-      - name: Set up conda environment
-        uses: mamba-org/setup-micromamba@v1
-        with:
-          environment-file: ci/${{ matrix.env }}.yml
-          environment-name: flox-tests
-          init-shell: bash
-          cache-environment: true
-          create-args: |
-            python=${{ matrix.python-version }}
-      - name: Install flox
-        run: |
-          python -m pip install --no-deps -e .
-      - name: Run tests
-        run: |
-          python -m pytest -n auto --cov=./ --cov-report=xml
-      - name: Upload code coverage to Codecov
-        uses: codecov/[email protected]
-        with:
-          file: ./coverage.xml
-          flags: unittests
-          env_vars: RUNNER_OS
-          name: codecov-umbrella
-          fail_ci_if_error: false
-
   xarray-groupby:
     name: xarray-groupby
     runs-on: ubuntu-latest
Empty file added asv_bench/__init__.py
1 change: 1 addition & 0 deletions ci/environment.yml
@@ -18,6 +18,7 @@ dependencies:
   - pytest-cov
   - pytest-pretty
   - pytest-xdist
+  - syrupy
   - xarray
   - pre-commit
   - numpy_groupies>=0.9.19
1 change: 1 addition & 0 deletions ci/minimal-requirements.yml
@@ -9,6 +9,7 @@ dependencies:
   - pytest-cov
   - pytest-pretty
   - pytest-xdist
+  - syrupy
   - numpy==1.22
   - scipy==1.9.0
   - numpy_groupies==0.9.19
1 change: 1 addition & 0 deletions ci/no-dask.yml
@@ -13,6 +13,7 @@ dependencies:
   - pytest-cov
   - pytest-pretty
   - pytest-xdist
+  - syrupy
   - xarray
   - numpydoc
   - pre-commit
1 change: 1 addition & 0 deletions ci/no-numba.yml
@@ -18,6 +18,7 @@ dependencies:
   - pytest-cov
   - pytest-pretty
   - pytest-xdist
+  - syrupy
   - xarray
   - pre-commit
   - numpy_groupies>=0.9.19
2 changes: 2 additions & 0 deletions ci/no-xarray.yml
@@ -3,6 +3,7 @@ channels:
   - conda-forge
 dependencies:
   - codecov
+  - syrupy
   - pandas
   - numpy>=1.22
   - scipy
@@ -11,6 +12,7 @@ dependencies:
   - pytest-cov
   - pytest-pretty
   - pytest-xdist
+  - syrupy
   - dask-core
   - numpydoc
   - pre-commit
1 change: 1 addition & 0 deletions ci/upstream-dev-env.yml
@@ -11,6 +11,7 @@ dependencies:
   # - scipy
   - pytest-pretty
   - pytest-xdist
+  - syrupy
   - pip
   # for cftime
   - cython>=0.29.20
10 changes: 6 additions & 4 deletions flox/core.py
@@ -394,7 +394,9 @@ def find_group_cohorts(
     chunks_per_label = chunks_per_label[present_labels_mask]

     label_chunks = {
-        present_labels[idx]: bitmask.indices[slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])]
+        present_labels[idx].item(): bitmask.indices[
+            slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])
+        ]
         for idx in range(bitmask.shape[LABEL_AXIS])
     }

@@ -485,7 +487,7 @@ def invert(x) -> tuple[np.ndarray, ...]:

     # Iterate over labels, beginning with those with most chunks
     logger.debug("find_group_cohorts: merging cohorts")
-    order = np.argsort(containment.sum(axis=LABEL_AXIS))[::-1]
+    order = np.argsort(containment.sum(axis=LABEL_AXIS), kind="stable")[::-1]
     merged_cohorts = {}
     merged_keys = set()
     # TODO: we can optimize this to loop over chunk_cohorts instead
@@ -495,11 +497,11 @@ def invert(x) -> tuple[np.ndarray, ...]:
             slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
         ]
         cohort_ = present_labels[cohidx]
-        cohort = [elem for elem in cohort_ if elem not in merged_keys]
+        cohort = [elem.item() for elem in cohort_ if elem not in merged_keys]
         if not cohort:
             continue
         merged_keys.update(cohort)
-        allchunks = (label_chunks[member] for member in cohort)
+        allchunks = (label_chunks[member].tolist() for member in cohort)
         chunk = tuple(set(itertools.chain(*allchunks)))
         merged_cohorts[chunk] = cohort

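The flox/core.py tweaks read as groundwork for stable snapshots: `.item()` and `.tolist()` turn NumPy scalars and index arrays into plain Python ints and lists so the serialized cohorts are clean and version-independent, and `kind="stable"` pins the tie-breaking order of `np.argsort` so cohorts are merged in a deterministic order. A standalone illustration of both effects (not flox code):

import numpy as np

counts = np.array([2, 3, 3, 1])

# A stable sort gives ties a fixed relative order, so the descending
# traversal below is identical on every run and platform.
order = np.argsort(counts, kind="stable")[::-1]
print(order)  # [2 1 0 3]

# NumPy scalars and arrays do not serialize like builtins; .item() and
# .tolist() produce plain Python objects for the snapshot files.
print(type(counts[0]))         # a NumPy scalar type, e.g. <class 'numpy.int64'>
print(type(counts[0].item()))  # <class 'int'>
print(counts.tolist())         # [2, 3, 3, 1]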
(Diffs for the remaining changed files, which account for the bulk of the 22,362 added lines, are not shown here.)
