Skip to content

Commit 897e499

Browse files
author
Itamar Ravid
committed
Add a typed col function for creating column references
Resolves #186.
1 parent 68aa838 commit 897e499

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

dataset/src/main/scala/frameless/functions/package.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package frameless
22

33
import org.apache.spark.sql.catalyst.ScalaReflection
44
import org.apache.spark.sql.catalyst.expressions.Literal
5+
import org.apache.spark.sql.functions.{ col => sparkCol }
6+
import shapeless.Witness
57

68
package object functions extends Udf with UnaryFunctions {
79
object aggregate extends AggregateFunctions
@@ -17,4 +19,12 @@ package object functions extends Udf with UnaryFunctions {
1719
new TypedColumn(expr)
1820
}
1921
}
22+
23+
def col[T, A](column: Witness.Lt[Symbol])(
24+
implicit
25+
exists: TypedColumn.Exists[T, column.T, A],
26+
encoder: TypedEncoder[A]): TypedColumn[T, A] = {
27+
val untypedExpr = sparkCol(column.value.name).as[A](TypedExpressionEncoder[A])
28+
new TypedColumn[T, A](untypedExpr)
29+
}
2030
}

dataset/src/test/scala/frameless/SelectTests.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ class SelectTests extends TypedDatasetSuite {
1818
val A = dataset.col[A]('a)
1919

2020
val dataset2 = dataset.select(A).collect().run().toVector
21+
val symDataset2 = dataset.select(functions.col('a)).collect().run().toVector
2122
val data2 = data.map { case X4(a, _, _, _) => a }
2223

23-
dataset2 ?= data2
24+
(dataset2 ?= data2) && (symDataset2 ?= data2)
2425
}
2526

2627
check(forAll(prop[Int, Int, Int, Int] _))

0 commit comments

Comments
 (0)