1616
1717package com .johnsnowlabs .nlp
1818
19+ import com .johnsnowlabs .nlp .LegacyMetadataSupport .ParamsReflection
20+ import org .apache .hadoop .fs .Path
21+ import org .apache .spark .internal .Logging
22+ import org .apache .spark .ml .param .Params
1923import org .apache .spark .ml .util .{DefaultParamsReadable , MLReader }
2024import org .apache .spark .sql .SparkSession
25+ import org .json4s .jackson .JsonMethods .{compact , parse , render }
26+ import org .json4s .{DefaultFormats , JNothing , JNull , JObject , JValue }
2127
2228import scala .collection .mutable .ArrayBuffer
2329
2430class FeaturesReader [T <: HasFeatures ](
2531 baseReader : MLReader [T ],
2632 onRead : (T , String , SparkSession ) => Unit )
27- extends MLReader [T ] {
33+ extends MLReader [T ]
34+ with Logging {
2835
2936 override def load (path : String ): T = {
3037
31- val instance = baseReader.load(path)
38+ val instance =
39+ try {
40+ // Let Spark's own loader handle modern bundles.
41+ baseReader.load(path)
42+ } catch {
43+ case e : NoSuchElementException if isMissingParamError(e) =>
44+ // Reconstruct legacy models that referenced params removed in newer releases.
45+ loadWithLegacyParams(path)
46+ }
3247
3348 for (feature <- instance.features) {
3449 val value = feature.deserialize(sparkSession, path, feature.name)
@@ -39,6 +54,59 @@ class FeaturesReader[T <: HasFeatures](
3954
4055 instance
4156 }
57+
58+ private def isMissingParamError (e : NoSuchElementException ): Boolean = {
59+ val msg = Option (e.getMessage).getOrElse(" " )
60+ msg.contains(" Param" )
61+ }
62+
63+ private def loadWithLegacyParams (path : String ): T = {
64+ val metadata = LegacyMetadataSupport .load(path, sparkSession)
65+ val cls = Class .forName(metadata.className)
66+ val ctor = cls.getConstructor(classOf [String ])
67+ val instance = ctor.newInstance(metadata.uid).asInstanceOf [Params ]
68+ setParamsIgnoringUnknown(instance, metadata)
69+ instance.asInstanceOf [T ]
70+ }
71+
72+ private def setParamsIgnoringUnknown (
73+ instance : Params ,
74+ metadata : LegacyMetadataSupport .Metadata ): Unit = {
75+ // Replay active params; skip mismatches so legacy bundles still come back.
76+ assignParams(instance, metadata.params, isDefault = false , metadata)
77+
78+ val hasDefaultSection = metadata.defaultParams != JNothing && metadata.defaultParams != JNull
79+ if (hasDefaultSection) {
80+ // If the metadata carried defaults, restore only those that still exists.
81+ assignParams(instance, metadata.defaultParams, isDefault = true , metadata)
82+ }
83+ }
84+
85+ private def assignParams (
86+ instance : Params ,
87+ jsonParams : JValue ,
88+ isDefault : Boolean ,
89+ metadata : LegacyMetadataSupport .Metadata ): Unit = {
90+ jsonParams match {
91+ case JObject (pairs) =>
92+ pairs.foreach { case (paramName, jsonValue) =>
93+ if (instance.hasParam(paramName)) {
94+ val param = instance.getParam(paramName)
95+ val value = param.jsonDecode(compact(render(jsonValue)))
96+ if (isDefault) {
97+ // Spark keeps setDefault protected; call it via reflection to restore legacy defaults.
98+ ParamsReflection .setDefault(instance, param, value)
99+ } else {
100+ instance.set(param, value)
101+ }
102+ }
103+ }
104+ case JNothing | JNull =>
105+ case other =>
106+ throw new IllegalArgumentException (
107+ s " Cannot recognize JSON metadata when loading legacy params for ${metadata.className}: $other" )
108+ }
109+ }
42110}
43111
44112trait ParamsAndFeaturesReadable [T <: HasFeatures ] extends DefaultParamsReadable [T ] {
@@ -60,3 +128,79 @@ trait ParamsAndFeaturesReadable[T <: HasFeatures] extends DefaultParamsReadable[
60128 super .read,
61129 (instance : T , path : String , spark : SparkSession ) => onRead(instance, path, spark))
62130}
131+
132+ // Minimal metadata parser + helper utilities for replaying legacy params.
/** Minimal metadata parser + helper utilities for replaying legacy params.
  *
  * FIX: a bare `protected` modifier is not allowed on top-level definitions in
  * Scala (compile error); `private[nlp]` keeps the helper package-internal,
  * which preserves the original intent of restricted visibility.
  */
private[nlp] object LegacyMetadataSupport {

  /** Reflection shim around Spark's protected `Params.setDefault(param, value)`. */
  object ParamsReflection {

    // Resolved once at class-load time; the 2-parameter overload is the
    // single-param setDefault(param, value) — the varargs overload takes one.
    private val setDefaultMethod = {
      val maybeMethod = classOf[Params].getDeclaredMethods.find { method =>
        method.getName == "setDefault" && method.getParameterCount == 2
      }

      maybeMethod match {
        case Some(method) =>
          method.setAccessible(true)
          method
        case None =>
          throw new NoSuchMethodException("Params.setDefault(Param, value) not found via reflection")
      }
    }

    /** Restores a default value on `params` through the protected setter. */
    def setDefault[T](
        params: Params,
        param: org.apache.spark.ml.param.Param[T],
        value: T): Unit = {
      setDefaultMethod.invoke(params, param, toAnyRef(value))
    }

    // FIX: a value typed as `Any` is already boxed on the JVM, so every
    // non-null value matched the original `case v: AnyRef` arm and the
    // per-primitive cases (and the trailing IllegalArgumentException) were
    // unreachable. A cast — which also passes `null` through — is all
    // Method.invoke needs.
    private def toAnyRef(value: Any): AnyRef = value.asInstanceOf[AnyRef]
  }

  /** Parsed subset of a Spark ML `metadata` JSON file.
    *
    * @param className     fully-qualified class name of the persisted model
    * @param uid           the model's uid, passed to its String constructor
    * @param sparkVersion  writer's Spark version; "0.0" when absent
    * @param params        the `paramMap` JSON section (may be JNothing)
    * @param defaultParams the `defaultParamMap` JSON section (may be JNothing)
    * @param metadataJson  the raw metadata line, kept for diagnostics
    */
  case class Metadata(
      className: String,
      uid: String,
      sparkVersion: String,
      params: JValue,
      defaultParams: JValue,
      metadataJson: String)

  /** Reads the single-line `metadata` file stored under `path` and parses it. */
  def load(path: String, spark: SparkSession): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = spark.sparkContext.textFile(metadataPath, 1).first()
    parseMetadata(metadataStr)
  }

  /** Extracts only the fields the legacy loader replays; `sparkVersion` is
    * optional in very old bundles and defaults to "0.0".
    */
  private def parseMetadata(metadataStr: String): Metadata = {
    val metadata = parse(metadataStr)
    implicit val format: DefaultFormats.type = DefaultFormats

    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val sparkVersion = (metadata \ "sparkVersion").extractOpt[String].getOrElse("0.0")
    val params = metadata \ "paramMap"
    val defaultParams = metadata \ "defaultParamMap"

    Metadata(className, uid, sparkVersion, params, defaultParams, metadataStr)
  }
}
0 commit comments