Skip to content

Commit 51b9d71

Browse files
committed
Add a test case for the new TypedRow encoder
Implemented the proposal.
1 parent 75fc432 commit 51b9d71

File tree

7 files changed

+209
-43
lines changed

7 files changed

+209
-43
lines changed

dataset/src/main/scala/frameless/RecordEncoder.scala

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -119,20 +119,19 @@ object DropUnitValues {
119119
}
120120
}
121121

122-
class RecordEncoder[F, G <: HList, H <: HList]
122+
abstract class RecordEncoder[F, G <: HList, H <: HList]
123123
(implicit
124-
i0: LabelledGeneric.Aux[F, G],
125-
i1: DropUnitValues.Aux[G, H],
126-
i2: IsHCons[H],
127-
fields: Lazy[RecordEncoderFields[H]],
128-
newInstanceExprs: Lazy[NewInstanceExprs[G]],
124+
stage1: RecordEncoderStage1[G, H],
129125
classTag: ClassTag[F]
130126
) extends TypedEncoder[F] {
127+
128+
import stage1._
129+
131130
def nullable: Boolean = false
132131

133-
def jvmRepr: DataType = FramelessInternals.objectTypeFor[F]
132+
lazy val jvmRepr: DataType = FramelessInternals.objectTypeFor[F]
134133

135-
def catalystRepr: DataType = {
134+
lazy val catalystRepr: DataType = {
136135
val structFields = fields.value.value.map { field =>
137136
StructField(
138137
name = field.name,
@@ -145,41 +144,94 @@ class RecordEncoder[F, G <: HList, H <: HList]
145144
StructType(structFields)
146145
}
147146

147+
}
148+
149+
object RecordEncoder {
150+
151+
case class ForGeneric[F, G <: HList, H <: HList](
152+
)(implicit
153+
stage1: RecordEncoderStage1[G, H],
154+
classTag: ClassTag[F])
155+
extends RecordEncoder[F, G, H] {
156+
157+
import stage1._
158+
148159
def toCatalyst(path: Expression): Expression = {
149-
val nameExprs = fields.value.value.map { field =>
150-
Literal(field.name)
151-
}
152160

153161
val valueExprs = fields.value.value.map { field =>
154162
val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil)
155163
field.encoder.toCatalyst(fieldPath)
156164
}
157165

158-
// the way exprs are encoded in CreateNamedStruct
159-
val exprs = nameExprs.zip(valueExprs).flatMap {
160-
case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
161-
}
166+
val createExpr = stage1.cellsToCatalyst(valueExprs)
162167

163-
val createExpr = CreateNamedStruct(exprs)
164168
val nullExpr = Literal.create(null, createExpr.dataType)
165169

166170
If(IsNull(path), nullExpr, createExpr)
167171
}
168172

169173
def fromCatalyst(path: Expression): Expression = {
170-
val exprs = fields.value.value.map { field =>
171-
field.encoder.fromCatalyst(
172-
GetStructField(path, field.ordinal, Some(field.name)))
173-
}
174174

175-
val newArgs = newInstanceExprs.value.from(exprs)
175+
val newArgs = stage1.fromCatalystToCells(path)
176176
val newExpr = NewInstance(
177177
classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)
178178

179179
val nullExpr = Literal.create(null, jvmRepr)
180180

181181
If(IsNull(path), nullExpr, newExpr)
182182
}
183+
}
184+
185+
case class ForTypedRow[G <: HList, H <: HList](
186+
)(implicit
187+
stage1: RecordEncoderStage1[G, H],
188+
classTag: ClassTag[TypedRow[G]])
189+
extends RecordEncoder[TypedRow[G], G, H] {
190+
191+
import stage1._
192+
193+
private final val _apply = "apply"
194+
private final val _fromInternalRow = "fromInternalRow"
195+
196+
def toCatalyst(path: Expression): Expression = {
197+
198+
val valueExprs = fields.value.value.zipWithIndex.map {
199+
case (field, i) =>
200+
val fieldPath = Invoke(
201+
path,
202+
_apply,
203+
field.encoder.jvmRepr,
204+
Seq(Literal.create(i, IntegerType))
205+
)
206+
field.encoder.toCatalyst(fieldPath)
207+
}
208+
209+
val createExpr = stage1.cellsToCatalyst(valueExprs)
210+
211+
val nullExpr = Literal.create(null, createExpr.dataType)
212+
213+
If(IsNull(path), nullExpr, createExpr)
214+
}
215+
216+
def fromCatalyst(path: Expression): Expression = {
217+
218+
val newArgs = stage1.fromCatalystToCells(path)
219+
val aggregated = CreateStruct(newArgs)
220+
221+
val partial = TypedRow.WithCatalystTypes(newArgs.map(_.dataType))
222+
223+
val newExpr = Invoke(
224+
Literal.fromObject(partial),
225+
_fromInternalRow,
226+
TypedRow.catalystType,
227+
Seq(aggregated)
228+
)
229+
230+
val nullExpr = Literal.create(null, jvmRepr)
231+
232+
If(IsNull(path), nullExpr, newExpr)
233+
}
234+
}
183235
}
184236

185237
final class RecordFieldEncoder[T](
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package frameless
2+
3+
import org.apache.spark.sql.catalyst.expressions.{
4+
CreateNamedStruct,
5+
Expression,
6+
GetStructField,
7+
Literal
8+
}
9+
import shapeless.{ HList, Lazy }
10+
11+
case class RecordEncoderStage1[G <: HList, H <: HList](
12+
)(implicit
13+
// i1: DropUnitValues.Aux[G, H],
14+
// i2: IsHCons[H],
15+
val fields: Lazy[RecordEncoderFields[H]],
16+
val newInstanceExprs: Lazy[NewInstanceExprs[G]]) {
17+
18+
def cellsToCatalyst(valueExprs: Seq[Expression]): Expression = {
19+
val nameExprs = fields.value.value.map { field => Literal(field.name) }
20+
21+
// the way exprs are encoded in CreateNamedStruct
22+
val exprs = nameExprs.zip(valueExprs).flatMap {
23+
case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
24+
}
25+
26+
val createExpr = CreateNamedStruct(exprs)
27+
createExpr
28+
}
29+
30+
def fromCatalystToCells(path: Expression): Seq[Expression] = {
31+
val exprs = fields.value.value.map { field =>
32+
field.encoder.fromCatalyst(
33+
GetStructField(path, field.ordinal, Some(field.name))
34+
)
35+
}
36+
37+
val newArgs = newInstanceExprs.value.from(exprs)
38+
newArgs
39+
}
40+
}
41+
42+
object RecordEncoderStage1 {
43+
44+
implicit def usingDerivation[G <: HList, H <: HList](
45+
implicit
46+
i3: Lazy[RecordEncoderFields[H]],
47+
i4: Lazy[NewInstanceExprs[G]]
48+
): RecordEncoderStage1[G, H] = RecordEncoderStage1[G, H]()
49+
}

dataset/src/main/scala/frameless/TypedEncoder.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -727,15 +727,23 @@ object TypedEncoder {
727727
}
728728

729729
/** Encodes things as records if there is no Injection defined */
730-
implicit def usingDerivation[F, G <: HList, H <: HList](
730+
implicit def deriveForGeneric[F, G <: HList, H <: HList](
731731
implicit
732732
i0: LabelledGeneric.Aux[F, G],
733733
i1: DropUnitValues.Aux[G, H],
734734
i2: IsHCons[H],
735735
i3: Lazy[RecordEncoderFields[H]],
736736
i4: Lazy[NewInstanceExprs[G]],
737737
i5: ClassTag[F]
738-
): TypedEncoder[F] = new RecordEncoder[F, G, H]
738+
): TypedEncoder[F] = RecordEncoder.ForGeneric[F, G, H]()
739+
740+
implicit def deriveForTypedRow[G <: HList, H <: HList](
741+
implicit
742+
i1: DropUnitValues.Aux[G, H],
743+
i2: IsHCons[H],
744+
i3: Lazy[RecordEncoderFields[H]],
745+
i4: Lazy[NewInstanceExprs[G]]
746+
): TypedEncoder[TypedRow[G]] = RecordEncoder.ForTypedRow[G, H]()
739747

740748
/** Encodes things using a Spark SQL's User Defined Type (UDT) if there is one defined in implicit */
741749
implicit def usingUserDefinedType[
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package frameless
2+
3+
import org.apache.spark.sql.Row
4+
import org.apache.spark.sql.catalyst.InternalRow
5+
import org.apache.spark.sql.types.{ DataType, ObjectType }
6+
import shapeless.HList
7+
8+
case class TypedRow[T <: HList](row: Row) {
9+
10+
def apply(i: Int): Any = row.apply(i)
11+
}
12+
13+
object TypedRow {
14+
15+
def apply(values: Any*): TypedRow[HList] = {
16+
17+
val row = Row.fromSeq(values)
18+
TypedRow(row)
19+
}
20+
21+
case class WithCatalystTypes(schema: Seq[DataType]) {
22+
23+
def fromInternalRow(row: InternalRow): TypedRow[HList] = {
24+
val data = row.toSeq(schema).toArray
25+
26+
apply(data: _*)
27+
}
28+
29+
}
30+
31+
object WithCatalystTypes {}
32+
33+
def fromHList[T <: HList](
34+
hlist: T
35+
): TypedRow[T] = {
36+
37+
val cells = hlist.runtimeList
38+
39+
val row = Row.fromSeq(cells)
40+
TypedRow(row)
41+
}
42+
43+
lazy val catalystType: ObjectType = ObjectType(classOf[TypedRow[_]])
44+
45+
}

dataset/src/test/scala/frameless/InjectionTests.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ class InjectionTests extends TypedDatasetSuite {
180180
}
181181

182182
test("Resolve ambiguity by importing usingDerivation") {
183-
import TypedEncoder.usingDerivation
183+
import TypedEncoder.deriveForGeneric
184184
assert(implicitly[TypedEncoder[Person]].isInstanceOf[RecordEncoder[Person, _, _]])
185185
check(forAll(prop[Person] _))
186186
}

dataset/src/test/scala/frameless/RecordEncoderTests.scala

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,12 @@
11
package frameless
22

3-
import org.apache.spark.sql.{Row, functions => F}
4-
import org.apache.spark.sql.types.{
5-
ArrayType,
6-
BinaryType,
7-
DecimalType,
8-
IntegerType,
9-
LongType,
10-
MapType,
11-
ObjectType,
12-
StringType,
13-
StructField,
14-
StructType
15-
}
16-
17-
import shapeless.{HList, LabelledGeneric}
18-
import shapeless.test.illTyped
19-
3+
import frameless.RecordEncoderTests.{ A, B, E }
4+
import org.apache.spark.sql.types._
5+
import org.apache.spark.sql.{ Row, functions => F }
206
import org.scalatest.matchers.should.Matchers
7+
import shapeless.record.Record
8+
import shapeless.test.illTyped
9+
import shapeless.{ HList, LabelledGeneric }
2110

2211
final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
2312
test("Unable to encode products made from units only") {
@@ -87,6 +76,26 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
8776
ds.collect.head shouldBe obj
8877
}
8978

79+
test("TypedRow") {
80+
81+
val r1: RecordEncoderTests.RR = Record(x = 1, y = "abc")
82+
val r2: TypedRow[RecordEncoderTests.RR] = TypedRow.fromHList(r1)
83+
84+
val rdd = sc.parallelize(Seq(r2))
85+
val ds =
86+
session.createDataset(rdd)(
87+
TypedExpressionEncoder[TypedRow[RecordEncoderTests.RR]]
88+
)
89+
90+
ds.schema.treeString shouldBe
91+
"""root
92+
| |-- x: integer (nullable = true)
93+
| |-- y: string (nullable = true)
94+
|""".stripMargin
95+
96+
ds.collect.head shouldBe r2
97+
}
98+
9099
test("Scalar value class") {
91100
import RecordEncoderTests._
92101

@@ -548,6 +557,9 @@ object RecordEncoderTests {
548557
case class D(m: Map[String, Int])
549558
case class E(b: Set[B])
550559

560+
val RR = Record.`'x -> Int, 'y -> String`
561+
type RR = RR.T
562+
551563
final class Subject(val name: String) extends AnyVal with Serializable
552564

553565
final class Grade(val value: BigDecimal) extends AnyVal with Serializable

refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ object RefinedTypesTests {
114114

115115
import frameless.refined._ // implicit instances for refined
116116

117-
implicit val encoderA: TypedEncoder[A] = TypedEncoder.usingDerivation
117+
implicit val encoderA: TypedEncoder[A] = TypedEncoder.deriveForGeneric
118118

119-
implicit val encoderB: TypedEncoder[B] = TypedEncoder.usingDerivation
119+
implicit val encoderB: TypedEncoder[B] = TypedEncoder.deriveForGeneric
120120
}

0 commit comments

Comments (0)