
Commit 4955483

Adding changes to load models with complex objects
1 parent 41ae0fa commit 4955483

File tree

2 files changed: 168 additions & 3 deletions

src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala
src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala

2 files changed

+168
-3
lines changed

src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala

Lines changed: 146 additions & 2 deletions
@@ -16,19 +16,34 @@
 
 package com.johnsnowlabs.nlp
 
+import com.johnsnowlabs.nlp.LegacyMetadataSupport.ParamsReflection
+import org.apache.hadoop.fs.Path
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.param.Params
 import org.apache.spark.ml.util.{DefaultParamsReadable, MLReader}
 import org.apache.spark.sql.SparkSession
+import org.json4s.jackson.JsonMethods.{compact, parse, render}
+import org.json4s.{DefaultFormats, JNothing, JNull, JObject, JValue}
 
 import scala.collection.mutable.ArrayBuffer
 
 class FeaturesReader[T <: HasFeatures](
     baseReader: MLReader[T],
     onRead: (T, String, SparkSession) => Unit)
-    extends MLReader[T] {
+    extends MLReader[T]
+    with Logging {
 
   override def load(path: String): T = {
 
-    val instance = baseReader.load(path)
+    val instance =
+      try {
+        // Let Spark's own loader handle modern bundles.
+        baseReader.load(path)
+      } catch {
+        case e: NoSuchElementException if isMissingParamError(e) =>
+          // Reconstruct legacy models that referenced params removed in newer releases.
+          loadWithLegacyParams(path)
+      }
 
     for (feature <- instance.features) {
       val value = feature.deserialize(sparkSession, path, feature.name)
@@ -39,6 +54,59 @@ class FeaturesReader[T <: HasFeatures](
 
     instance
   }
+
+  private def isMissingParamError(e: NoSuchElementException): Boolean = {
+    val msg = Option(e.getMessage).getOrElse("")
+    msg.contains("Param")
+  }
+
+  private def loadWithLegacyParams(path: String): T = {
+    val metadata = LegacyMetadataSupport.load(path, sparkSession)
+    val cls = Class.forName(metadata.className)
+    val ctor = cls.getConstructor(classOf[String])
+    val instance = ctor.newInstance(metadata.uid).asInstanceOf[Params]
+    setParamsIgnoringUnknown(instance, metadata)
+    instance.asInstanceOf[T]
+  }
+
+  private def setParamsIgnoringUnknown(
+      instance: Params,
+      metadata: LegacyMetadataSupport.Metadata): Unit = {
+    // Replay active params; skip mismatches so legacy bundles still come back.
+    assignParams(instance, metadata.params, isDefault = false, metadata)
+
+    val hasDefaultSection = metadata.defaultParams != JNothing && metadata.defaultParams != JNull
+    if (hasDefaultSection) {
+      // If the metadata carried defaults, restore only those that still exist.
+      assignParams(instance, metadata.defaultParams, isDefault = true, metadata)
+    }
+  }
+
+  private def assignParams(
+      instance: Params,
+      jsonParams: JValue,
+      isDefault: Boolean,
+      metadata: LegacyMetadataSupport.Metadata): Unit = {
+    jsonParams match {
+      case JObject(pairs) =>
+        pairs.foreach { case (paramName, jsonValue) =>
+          if (instance.hasParam(paramName)) {
+            val param = instance.getParam(paramName)
+            val value = param.jsonDecode(compact(render(jsonValue)))
+            if (isDefault) {
+              // Spark keeps setDefault protected; call it via reflection to restore legacy defaults.
+              ParamsReflection.setDefault(instance, param, value)
+            } else {
+              instance.set(param, value)
+            }
+          }
+        }
+      case JNothing | JNull =>
+      case other =>
+        throw new IllegalArgumentException(
+          s"Cannot recognize JSON metadata when loading legacy params for ${metadata.className}: $other")
+    }
+  }
 }
 
 trait ParamsAndFeaturesReadable[T <: HasFeatures] extends DefaultParamsReadable[T] {
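Note: the guard above keys on the exception Spark raises when persisted metadata names a param the current class no longer declares. Spark's Params.getParam throws a NoSuchElementException whose message mentions the param name. A minimal sketch of that trigger, using Spark's own Tokenizer and a hypothetical "removedParam" name (not a real Tokenizer param):

import org.apache.spark.ml.feature.Tokenizer

object MissingParamDemo extends App {
  val tokenizer = new Tokenizer()
  try {
    // DefaultParamsReader calls getParam for every param name stored in metadata.
    tokenizer.getParam("removedParam") // hypothetical name no longer declared
  } catch {
    case e: NoSuchElementException =>
      // Prints something like: "Param removedParam does not exist."
      // The "Param" substring is what isMissingParamError keys on.
      println(e.getMessage)
  }
}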
@@ -60,3 +128,79 @@ trait ParamsAndFeaturesReadable[T <: HasFeatures] extends DefaultParamsReadable[
       super.read,
       (instance: T, path: String, spark: SparkSession) => onRead(instance, path, spark))
 }
+
+// Minimal metadata parser + helper utilities for replaying legacy params.
+private[nlp] object LegacyMetadataSupport {
+
+  object ParamsReflection {
+    private val setDefaultMethod = {
+      val maybeMethod = classOf[Params].getDeclaredMethods.find { method =>
+        method.getName == "setDefault" && method.getParameterCount == 2
+      }
+
+      maybeMethod match {
+        case Some(method) =>
+          method.setAccessible(true)
+          method
+        case None =>
+          throw new NoSuchMethodException("Params.setDefault(Param, value) not found via reflection")
+      }
+    }
+
+    def setDefault[T](
+        params: Params,
+        param: org.apache.spark.ml.param.Param[T],
+        value: T): Unit = {
+      setDefaultMethod.invoke(params, param, toAnyRef(value))
+    }
+
+    // Mirror JVM boxing rules so reflection can call the protected method safely.
+    private def toAnyRef(value: Any): AnyRef = {
+      if (value == null) {
+        null
+      } else {
+        value match {
+          case v: AnyRef => v
+          case v: Boolean => java.lang.Boolean.valueOf(v)
+          case v: Byte => java.lang.Byte.valueOf(v)
+          case v: Short => java.lang.Short.valueOf(v)
+          case v: Int => java.lang.Integer.valueOf(v)
+          case v: Long => java.lang.Long.valueOf(v)
+          case v: Float => java.lang.Float.valueOf(v)
+          case v: Double => java.lang.Double.valueOf(v)
+          case v: Char => java.lang.Character.valueOf(v)
+          case other =>
+            throw new IllegalArgumentException(
+              s"Unsupported default value type ${other.getClass}")
+        }
+      }
+    }
+  }
+
+  case class Metadata(
+      className: String,
+      uid: String,
+      sparkVersion: String,
+      params: JValue,
+      defaultParams: JValue,
+      metadataJson: String)
+
+  def load(path: String, spark: SparkSession): Metadata = {
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataStr = spark.sparkContext.textFile(metadataPath, 1).first()
+    parseMetadata(metadataStr)
+  }
+
+  private def parseMetadata(metadataStr: String): Metadata = {
+    val metadata = parse(metadataStr)
+    implicit val format: DefaultFormats.type = DefaultFormats
+
+    val className = (metadata \ "class").extract[String]
+    val uid = (metadata \ "uid").extract[String]
+    val sparkVersion = (metadata \ "sparkVersion").extractOpt[String].getOrElse("0.0")
+    val params = metadata \ "paramMap"
+    val defaultParams = metadata \ "defaultParamMap"
+
+    Metadata(className, uid, sparkVersion, params, defaultParams, metadataStr)
+  }
+}
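Note: parseMetadata expects the single JSON line that Spark ML's DefaultParamsWriter stores under <path>/metadata. A hedged sketch of such a line and how its fields map onto Metadata; the class name, uid, and params below are made up for illustration:

import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse

object LegacyMetadataDemo extends App {
  // Field layout mirrors Spark ML's metadata convention; values are hypothetical.
  val metadataStr =
    """{"class":"com.johnsnowlabs.nlp.SomeLegacyModel","timestamp":1650000000000,""" +
      """"sparkVersion":"2.4.8","uid":"SOME_LEGACY_MODEL_1a2b3c",""" +
      """"paramMap":{"caseSensitive":false},"defaultParamMap":{"lazyAnnotator":false}}"""

  implicit val formats: DefaultFormats.type = DefaultFormats
  val metadata = parse(metadataStr)

  val className = (metadata \ "class").extract[String] // -> Metadata.className
  val uid = (metadata \ "uid").extract[String] // -> Metadata.uid
  val params = metadata \ "paramMap" // replayed with instance.set
  val defaults = metadata \ "defaultParamMap" // replayed with ParamsReflection.setDefault

  println(s"$className / $uid: params=$params defaults=$defaults")
}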

src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala

Lines changed: 22 additions & 1 deletion
@@ -21,7 +21,7 @@ import com.johnsnowlabs.nlp.annotators.SparkSessionTest
 import com.johnsnowlabs.nlp.annotators.er.EntityRulerFixture._
 import com.johnsnowlabs.nlp.base.LightPipeline
 import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs}
-import com.johnsnowlabs.tags.FastTest
+import com.johnsnowlabs.tags.{FastTest, SlowTest}
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.scalatest.flatspec.AnyFlatSpec
 
@@ -850,4 +850,25 @@ class EntityRulerTest extends AnyFlatSpec with SparkSessionTest {
     entityRulerPipeline
   }
 
+  it should "serialize EntityRulerModel" taggedAs SlowTest in {
+    // Should be run with Java 8 and Scala 2.12
+    val entityRuler = new EntityRulerApproach()
+      .setInputCols("document", "token")
+      .setOutputCol("entities")
+      .setPatternsResource("src/test/resources/entity-ruler/keywords_only.json", ReadAs.TEXT)
+    val entityRulerModel = entityRuler.fit(emptyDataSet)
+
+    entityRulerModel.write.overwrite().save("./tmp_entity_ruler_model_java8_scala2_12")
+  }
+
+  it should "deserialize EntityRulerModel" in {
+    val textDataSet = Seq(text1).toDS.toDF("text")
+    val loadedEntityRulerModel = EntityRulerModel.load("./tmp_entity_ruler_model_java8_scala2_12")
+
+    val pipeline =
+      new Pipeline().setStages(Array(documentAssembler, tokenizer, loadedEntityRulerModel))
+    val resultDf = pipeline.fit(emptyDataSet).transform(textDataSet)
+    resultDf.select("entities").show(truncate = false)
+  }
+
 }
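Note: the two tests split one round trip across toolchains: the tagged test writes the bundle under Java 8 / Scala 2.12, and the untagged one reloads it on the current build, which is where the legacy-param fallback can kick in. A condensed sketch of that round trip in a single run, reusing the suite's helpers (emptyDataSet, documentAssembler, and tokenizer come from SparkSessionTest):

// Sketch only: in the suite above, save and load happen in separate runs
// on different toolchains.
val ruler = new EntityRulerApproach()
  .setInputCols("document", "token")
  .setOutputCol("entities")
  .setPatternsResource("src/test/resources/entity-ruler/keywords_only.json", ReadAs.TEXT)

ruler.fit(emptyDataSet).write.overwrite().save("./tmp_entity_ruler_model_java8_scala2_12")

// Load routes through FeaturesReader.load; if Spark's reader trips over a
// since-removed param, loadWithLegacyParams reconstructs the model instead.
val restored = EntityRulerModel.load("./tmp_entity_ruler_model_java8_scala2_12")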
