From e451bf708137c28dcac7b0537cc0c23ddb2fcb93 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marko=20Milenkovi=C4=87?=
Date: Mon, 20 Oct 2025 19:39:40 +0100
Subject: [PATCH] with_column supports SQL expression

---
 python/datafusion/dataframe.py |  7 +++++--
 python/tests/test_dataframe.py | 22 +++++++++++++++-------
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index c21a3bc79..343e40275 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -521,11 +521,12 @@ def parse_sql_expr(self, expr: str) -> Expr:
         """
         return Expr(self.df.parse_sql_expr(expr))
 
-    def with_column(self, name: str, expr: Expr) -> DataFrame:
+    def with_column(self, name: str, expr: Expr | str) -> DataFrame:
         """Add an additional column to the DataFrame.
 
         The ``expr`` must be an :class:`~datafusion.expr.Expr` constructed with
-        :func:`datafusion.col` or :func:`datafusion.lit`.
+        :func:`datafusion.col` or :func:`datafusion.lit`, or a SQL expression
+        string that will be parsed against the DataFrame schema.
 
         Example::
 
@@ -539,6 +540,8 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
         Returns:
             DataFrame with the new column.
         """
+        expr = self.parse_sql_expr(expr) if isinstance(expr, str) else expr
+
         return DataFrame(self.df.with_column(name, ensure_expr(expr)))
 
     def with_columns(
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index aceebadb4..1aa3c1e43 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -477,8 +477,8 @@ def test_tail(df):
     assert result.column(2) == pa.array([8])
 
 
-def test_with_column(df):
-    df = df.with_column("c", column("a") + column("b"))
+def test_with_column_sql_expression(df):
+    df = df.with_column("c", "a + b")
 
     # execute and collect the first (and only) batch
     result = df.collect()[0]
@@ -492,11 +492,19 @@ def test_with_column(df):
     assert result.column(2) == pa.array([5, 7, 9])
 
 
-def test_with_column_invalid_expr(df):
-    with pytest.raises(
-        TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
-    ):
-        df.with_column("c", "a")
+def test_with_column(df):
+    df = df.with_column("c", column("a") + column("b"))
+
+    # execute and collect the first (and only) batch
+    result = df.collect()[0]
+
+    assert result.schema.field(0).name == "a"
+    assert result.schema.field(1).name == "b"
+    assert result.schema.field(2).name == "c"
+
+    assert result.column(0) == pa.array([1, 2, 3])
+    assert result.column(1) == pa.array([4, 5, 6])
+    assert result.column(2) == pa.array([5, 7, 9])
 
 
 def test_with_columns(df):