diff --git a/pandera/backends/polars/base.py b/pandera/backends/polars/base.py index ef7f7da13..2f5972dac 100644 --- a/pandera/backends/polars/base.py +++ b/pandera/backends/polars/base.py @@ -252,7 +252,9 @@ def drop_invalid_rows( valid_rows = check_outputs.select( valid_rows=pl.fold( acc=pl.lit(True), - function=lambda acc, x: acc & x, + # if nullable=True for a column, the check_outputs values will be null for that row + # if the row value is null + function=lambda acc, x: acc & (x | x.is_null()), exprs=pl.col(pl.Boolean), ) )["valid_rows"] diff --git a/tests/polars/test_polars_container.py b/tests/polars/test_polars_container.py index c44a9448a..edcaf10d7 100644 --- a/tests/polars/test_polars_container.py +++ b/tests/polars/test_polars_container.py @@ -313,16 +313,20 @@ def custom_check(data: PolarsData): "column_mod,filter_expr", [ ({"int_col": pl.Series([-1, 1, 1])}, pl.col("int_col").ge(0)), - ({"string_col": pl.Series([*"013"])}, pl.col("string_col").ne("d")), + ({"string_col": pl.Series([*"013"])}, pl.col("string_col").ne("3")), ( { "int_col": pl.Series([-1, 1, 1]), "string_col": pl.Series([*"013"]), }, - pl.col("int_col").ge(0) & pl.col("string_col").ne("d"), + pl.col("int_col").ge(0) & pl.col("string_col").ne("3"), ), ({"int_col": pl.lit(-1)}, pl.col("int_col").ge(0)), - ({"int_col": pl.lit("d")}, pl.col("string_col").ne("d")), + ({"string_col": pl.lit("d")}, pl.col("string_col").ne("d")), + ( + {"int_col": pl.Series([None, 1, 1])}, + pl.col("int_col").is_not_null(), + ), ], ) @pytest.mark.parametrize("lazy", [False, True]) @@ -334,7 +338,7 @@ def test_drop_invalid_rows( ldf_schema_with_check, ): ldf_schema_with_check.drop_invalid_rows = True - modified_data = ldf_basic.with_columns(column_mod) + modified_data = ldf_basic.with_columns(**column_mod) if lazy: validated_data = modified_data.pipe( ldf_schema_with_check.validate, @@ -350,6 +354,41 @@ def test_drop_invalid_rows( ) +@pytest.mark.parametrize( + "schema_column_updates,column_mod,filter_expr", + [ + ( + {"int_col": {"nullable": True}}, + {"int_col": pl.Series([None, -1, 1])}, + pl.col("int_col").is_null() | (pl.col("int_col") > 0), + ), + ( + {"int_col": {"nullable": True, "checks": [pa.Check.isin([1])]}}, + {"int_col": pl.Series([None, 1, 2])}, + pl.col("int_col").is_null() | (pl.col("int_col") == 1), + ), + ], +) +def test_drop_invalid_rows_nullable( + schema_column_updates, + column_mod, + filter_expr, + ldf_basic, + ldf_schema_with_check, +): + ldf_schema_with_check.drop_invalid_rows = True + nullable_schema = ldf_schema_with_check.update_columns( + schema_column_updates + ) + modified_data = ldf_basic.with_columns(**column_mod) + validated_data = modified_data.pipe( + nullable_schema.validate, + lazy=True, + ) + expected_valid_data = modified_data.filter(filter_expr) + assert validated_data.collect().equals(expected_valid_data.collect()) + + def test_set_defaults(ldf_basic, ldf_schema_basic): ldf_schema_basic.columns["int_col"].default = 1 ldf_schema_basic.columns["string_col"].default = "a"