Skip to content

Commit a523381

Browse files
committed
Extend to support column name
1 parent fdcd140 commit a523381

File tree

4 files changed

+15
-7
lines changed

4 files changed

+15
-7
lines changed

python/pyspark/sql/classic/dataframe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,8 +1638,10 @@ def withColumns(self, *colsMap: Dict[str, Column]) -> ParentDataFrame:
16381638
self.sparkSession,
16391639
)
16401640

1641-
def withColumn(self, colName: str, col: Column) -> ParentDataFrame:
1642-
if not isinstance(col, Column):
1641+
def withColumn(self, colName: str, col: "ColumnOrName") -> ParentDataFrame:
1642+
if isinstance(col, str):
1643+
col = F.col(col)
1644+
elif not isinstance(col, Column):
16431645
raise PySparkTypeError(
16441646
errorClass="NOT_COLUMN",
16451647
messageParameters={"arg_name": "col", "arg_type": type(col).__name__},

python/pyspark/sql/connect/dataframe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -932,8 +932,10 @@ def withColumns(self, *colsMap: Dict[str, Column]) -> ParentDataFrame:
932932
session=self._session,
933933
)
934934

935-
def withColumn(self, colName: str, col: Column) -> ParentDataFrame:
936-
if not isinstance(col, Column):
935+
def withColumn(self, colName: str, col: "ColumnOrName") -> ParentDataFrame:
936+
if isinstance(col, str):
937+
col = F.col(col)
938+
elif not isinstance(col, Column):
937939
raise PySparkTypeError(
938940
errorClass="NOT_COLUMN",
939941
messageParameters={"arg_name": "col", "arg_type": type(col).__name__},

python/pyspark/sql/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5651,7 +5651,7 @@ def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame":
56515651
...
56525652

56535653
@dispatch_df_method
5654-
def withColumn(self, colName: str, col: Column) -> "DataFrame":
5654+
def withColumn(self, colName: str, col: "ColumnOrName") -> "DataFrame":
56555655
"""
56565656
Returns a new :class:`DataFrame` by adding a column or replacing the
56575657
existing column that has the same name.
@@ -5668,8 +5668,8 @@ def withColumn(self, colName: str, col: Column) -> "DataFrame":
56685668
----------
56695669
colName : str
56705670
string, name of the new column.
5671-
col : :class:`Column`
5672-
a :class:`Column` expression for the new column.
5671+
col : str or :class:`Column`
5672+
the name of an existing column, or a :class:`Column` expression for the new column.
56735673
56745674
Returns
56755675
-------

python/pyspark/sql/tests/test_dataframe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,10 @@ def test_with_column_with_existing_name(self):
413413
keys = self.df.withColumn("key", self.df.key).select("key").collect()
414414
self.assertEqual([r.key for r in keys], list(range(100)))
415415

416+
def test_with_column_with_column_name(self):
417+
keys = self.df.withColumn("key", "key").select("key").collect()
418+
self.assertEqual([r.key for r in keys], list(range(100)))
419+
416420
# regression test for SPARK-10417
417421
def test_column_iterator(self):
418422
def foo():

0 commit comments

Comments
 (0)