diff --git a/src/centimators/losses.py b/src/centimators/losses.py index 124155f..a268aac 100644 --- a/src/centimators/losses.py +++ b/src/centimators/losses.py @@ -16,8 +16,10 @@ import keras.ops as K from keras.losses import Loss from keras.config import epsilon +from keras.saving import register_keras_serializable +@register_keras_serializable(package="centimators") class SpearmanCorrelation(Loss): """Differentiable Spearman rank correlation loss. @@ -114,7 +116,13 @@ def _correlation(self, x, y): return numerator / denominator + def get_config(self): + config = super().get_config() + config.update({"regularization_strength": self.regularization_strength}) + return config + +@register_keras_serializable(package="centimators") class CombinedLoss(Loss): """Weighted combination of MSE and Spearman correlation losses. @@ -168,3 +176,14 @@ def call(self, y_true, y_pred): spearman = self.spearman_loss(y_true, y_pred) return self.mse_weight * mse + self.spearman_weight * spearman + + def get_config(self): + config = super().get_config() + config.update( + { + "mse_weight": self.mse_weight, + "spearman_weight": self.spearman_weight, + "spearman_regularization": self.spearman_loss.regularization_strength, + } + ) + return config diff --git a/src/centimators/model_estimators/keras_estimators/tree.py b/src/centimators/model_estimators/keras_estimators/tree.py index 2638905..306b1d8 100644 --- a/src/centimators/model_estimators/keras_estimators/tree.py +++ b/src/centimators/model_estimators/keras_estimators/tree.py @@ -15,6 +15,7 @@ from .base import BaseKerasEstimator from keras import layers, models, ops as K, regularizers, callbacks, initializers +from keras.saving import register_keras_serializable class TemperatureAnnealing(callbacks.Callback): @@ -49,6 +50,7 @@ def on_epoch_end(self, epoch, logs=None): tree.temperature.assign(t) +@register_keras_serializable(package="centimators") class NeuralDecisionTree(models.Model): """A differentiable decision tree with stochastic routing. @@ -93,11 +95,18 @@ def __init__( l2_leaf=1e-3, temperature=0.5, rng=None, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) + # Store config params for serialization self.depth = depth + self.num_features = num_features + self.used_features_rate = used_features_rate self.num_leaves = 2**depth self.output_units = output_units + self.l2_decision = l2_decision + self.l2_leaf = l2_leaf + self._init_temperature = temperature # Store initial value for get_config # Create a mask for the randomly selected features num_used_features = max(1, int(round(num_features * used_features_rate))) @@ -129,9 +138,7 @@ def __init__( self.temperature = self.add_weight( name="temperature", shape=(), - initializer=lambda shape, dtype: K.convert_to_tensor( - temperature, dtype=dtype - ), + initializer=initializers.Constant(temperature), trainable=False, ) @@ -145,6 +152,21 @@ def __init__( else None, ) + def get_config(self): + config = super().get_config() + config.update( + { + "depth": self.depth, + "num_features": self.num_features, + "used_features_rate": self.used_features_rate, + "output_units": self.output_units, + "l2_decision": self.l2_decision, + "l2_leaf": self.l2_leaf, + "temperature": self._init_temperature, + } + ) + return config + def call(self, features): batch_size = K.shape(features)[0]