Skip to content

Commit

Permalink
don't drop null values when dropping invalid rows in polars with nullable=True
Browse files Browse the repository at this point in the history

Signed-off-by: Jacob Baldwin <[email protected]>
  • Loading branch information
baldwinj30 committed Dec 31, 2024
1 parent 1b8e925 commit 8e2d829
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 5 deletions.
4 changes: 3 additions & 1 deletion pandera/backends/polars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ def drop_invalid_rows(
valid_rows = check_outputs.select(
valid_rows=pl.fold(
acc=pl.lit(True),
function=lambda acc, x: acc & x,
# if nullable=True for a column, the check_outputs values will be null for that row
# if the row value is null
function=lambda acc, x: acc & (x | x.is_null()),
exprs=pl.col(pl.Boolean),
)
)["valid_rows"]
Expand Down
47 changes: 43 additions & 4 deletions tests/polars/test_polars_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,16 +313,20 @@ def custom_check(data: PolarsData):
"column_mod,filter_expr",
[
({"int_col": pl.Series([-1, 1, 1])}, pl.col("int_col").ge(0)),
({"string_col": pl.Series([*"013"])}, pl.col("string_col").ne("d")),
({"string_col": pl.Series([*"013"])}, pl.col("string_col").ne("3")),
(
{
"int_col": pl.Series([-1, 1, 1]),
"string_col": pl.Series([*"013"]),
},
pl.col("int_col").ge(0) & pl.col("string_col").ne("d"),
pl.col("int_col").ge(0) & pl.col("string_col").ne("3"),
),
({"int_col": pl.lit(-1)}, pl.col("int_col").ge(0)),
({"int_col": pl.lit("d")}, pl.col("string_col").ne("d")),
({"string_col": pl.lit("d")}, pl.col("string_col").ne("d")),
(
{"int_col": pl.Series([None, 1, 1])},
pl.col("int_col").is_not_null(),
),
],
)
@pytest.mark.parametrize("lazy", [False, True])
Expand All @@ -334,7 +338,7 @@ def test_drop_invalid_rows(
ldf_schema_with_check,
):
ldf_schema_with_check.drop_invalid_rows = True
modified_data = ldf_basic.with_columns(column_mod)
modified_data = ldf_basic.with_columns(**column_mod)
if lazy:
validated_data = modified_data.pipe(
ldf_schema_with_check.validate,
Expand All @@ -350,6 +354,41 @@ def test_drop_invalid_rows(
)


@pytest.mark.parametrize(
    "schema_column_updates,column_mod,filter_expr",
    [
        # Case 1: a null in a nullable column must NOT be treated as an
        # invalid row. Expected survivors: the null row plus rows passing
        # the schema's check (assumes ldf_schema_with_check requires
        # int_col > 0 — TODO confirm against the fixture).
        (
            {"int_col": {"nullable": True}},
            {"int_col": pl.Series([None, -1, 1])},
            pl.col("int_col").is_null() | (pl.col("int_col") > 0),
        ),
        # Case 2: same rule when the column carries an explicit isin check;
        # the null row survives, the out-of-set value (2) is dropped.
        (
            {"int_col": {"nullable": True, "checks": [pa.Check.isin([1])]}},
            {"int_col": pl.Series([None, 1, 2])},
            pl.col("int_col").is_null() | (pl.col("int_col") == 1),
        ),
    ],
)
def test_drop_invalid_rows_nullable(
    schema_column_updates,
    column_mod,
    filter_expr,
    ldf_basic,
    ldf_schema_with_check,
):
    """Rows that are null in a nullable column survive drop_invalid_rows.

    Regression test: with drop_invalid_rows=True, check outputs are null
    (not False) for null values in nullable columns, and such rows must be
    kept while genuinely invalid rows are still dropped.
    """
    # Enable row-dropping on the shared fixture schema, then mark the
    # target column(s) nullable (and optionally add checks) for this case.
    ldf_schema_with_check.drop_invalid_rows = True
    nullable_schema = ldf_schema_with_check.update_columns(
        schema_column_updates
    )
    # Inject the null-containing test data into the base lazy frame.
    modified_data = ldf_basic.with_columns(**column_mod)
    # lazy=True so validation collects all failures and drops invalid rows
    # instead of raising on the first error.
    validated_data = modified_data.pipe(
        nullable_schema.validate,
        lazy=True,
    )
    # The expected result is an independent polars filter expressing
    # "null OR passes the check" — compare materialized frames.
    expected_valid_data = modified_data.filter(filter_expr)
    assert validated_data.collect().equals(expected_valid_data.collect())


def test_set_defaults(ldf_basic, ldf_schema_basic):
ldf_schema_basic.columns["int_col"].default = 1
ldf_schema_basic.columns["string_col"].default = "a"
Expand Down

0 comments on commit 8e2d829

Please sign in to comment.