Skip to content

Commit

Permalink
move the fix for keeping nulls in polars when nullable=True to the ch…
Browse files Browse the repository at this point in the history
…eck backend

Signed-off-by: Jacob Baldwin <[email protected]>
  • Loading branch information
baldwinj30 committed Jan 3, 2025
1 parent 8e2d829 commit 17a1967
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
4 changes: 1 addition & 3 deletions pandera/backends/polars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,7 @@ def drop_invalid_rows(
valid_rows = check_outputs.select(
valid_rows=pl.fold(
acc=pl.lit(True),
# if nullable=True for a column, the check_outputs values will be null for that row
# if the row value is null
function=lambda acc, x: acc & (x | x.is_null()),
function=lambda acc, x: acc & x,
exprs=pl.col(pl.Boolean),
)
)["valid_rows"]
Expand Down
4 changes: 4 additions & 0 deletions pandera/backends/polars/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def postprocess_lazyframe_output(
) -> CheckResult:
"""Postprocesses the result of applying the check function."""
results = pl.LazyFrame(check_output.collect())
if self.check.ignore_na:
results = results.with_columns(

Check warning on line 98 in pandera/backends/polars/checks.py

View check run for this annotation

Codecov / codecov/patch

pandera/backends/polars/checks.py#L97-L98

Added lines #L97 - L98 were not covered by tests
pl.col(CHECK_OUTPUT_KEY) | pl.col(CHECK_OUTPUT_KEY).is_null()
)
passed = results.select([pl.col(CHECK_OUTPUT_KEY).all()])
failure_cases = pl.concat(
[check_obj.lazyframe, results], how="horizontal"
Expand Down
21 changes: 17 additions & 4 deletions tests/polars/test_polars_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,32 @@ def _column_check_fn_scalar_out(data: pa.PolarsData) -> pl.LazyFrame:


@pytest.mark.parametrize(
"check_fn, invalid_data, expected_output",
"check_fn, invalid_data, expected_output, ignore_na",
[
[_column_check_fn_df_out, [-1, 2, 3, -2], [False, True, True, False]],
[_column_check_fn_scalar_out, [-1, 2, 3, -2], [False]],
[
_column_check_fn_df_out,
[-1, 2, 3, -2],
[False, True, True, False],
False,
],
[_column_check_fn_scalar_out, [-1, 2, 3, -2], [False], False],
[
_column_check_fn_df_out,
[-1, 2, 3, None],
[False, True, True, True],
True,
],
[_column_check_fn_scalar_out, [-1, 2, 3, None], [False], True],
],
)
def test_polars_column_check(
column_lf,
check_fn,
invalid_data,
expected_output,
ignore_na,
):
check = pa.Check(check_fn)
check = pa.Check(check_fn, ignore_na=ignore_na)
check_result = check(column_lf, column="col")
assert check_result.check_passed.collect().item()

Expand Down

0 comments on commit 17a1967

Please sign in to comment.