19 changes: 19 additions & 0 deletions src/centimators/losses.py
@@ -16,8 +16,10 @@
import keras.ops as K
from keras.losses import Loss
from keras.config import epsilon
from keras.saving import register_keras_serializable


@register_keras_serializable(package="centimators")
class SpearmanCorrelation(Loss):
"""Differentiable Spearman rank correlation loss.

@@ -114,7 +116,13 @@ def _correlation(self, x, y):

return numerator / denominator

def get_config(self):
config = super().get_config()
config.update({"regularization_strength": self.regularization_strength})
return config


@register_keras_serializable(package="centimators")
class CombinedLoss(Loss):
"""Weighted combination of MSE and Spearman correlation losses.

@@ -168,3 +176,14 @@ def call(self, y_true, y_pred):
spearman = self.spearman_loss(y_true, y_pred)

return self.mse_weight * mse + self.spearman_weight * spearman

def get_config(self):
config = super().get_config()
config.update(
{
"mse_weight": self.mse_weight,
"spearman_weight": self.spearman_weight,
"spearman_regularization": self.spearman_loss.regularization_strength,
}
)
return config
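
The registrations and get_config methods above are what let Keras rebuild these losses when a saved model is reloaded. A minimal sketch of that round-trip, not part of this diff; the CombinedLoss keyword arguments are inferred from the get_config keys above and are assumptions about the constructor signature:

import keras
from centimators.losses import CombinedLoss

# Compile a toy model with the combined loss
# (keyword names assumed from the get_config keys above).
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(
    optimizer="adam",
    loss=CombinedLoss(mse_weight=0.5, spearman_weight=0.5),
)
model.build((None, 8))

# Because both losses are registered under the "centimators" package and
# implement get_config, load_model can reconstruct the compiled loss
# without a custom_objects mapping (the module only needs to be imported).
model.save("model.keras")
restored = keras.saving.load_model("model.keras")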
30 changes: 26 additions & 4 deletions src/centimators/model_estimators/keras_estimators/tree.py
@@ -15,6 +15,7 @@

from .base import BaseKerasEstimator
from keras import layers, models, ops as K, regularizers, callbacks, initializers
from keras.saving import register_keras_serializable


class TemperatureAnnealing(callbacks.Callback):
@@ -49,6 +50,7 @@ def on_epoch_end(self, epoch, logs=None):
tree.temperature.assign(t)


@register_keras_serializable(package="centimators")
class NeuralDecisionTree(models.Model):
"""A differentiable decision tree with stochastic routing.

@@ -93,11 +95,18 @@ def __init__(
l2_leaf=1e-3,
temperature=0.5,
rng=None,
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
# Store config params for serialization
self.depth = depth
self.num_features = num_features
self.used_features_rate = used_features_rate
self.num_leaves = 2**depth
self.output_units = output_units
self.l2_decision = l2_decision
self.l2_leaf = l2_leaf
self._init_temperature = temperature # Store initial value for get_config

# Create a mask for the randomly selected features
num_used_features = max(1, int(round(num_features * used_features_rate)))
@@ -129,9 +138,7 @@ def __init__(
self.temperature = self.add_weight(
name="temperature",
shape=(),
initializer=lambda shape, dtype: K.convert_to_tensor(
temperature, dtype=dtype
),
initializer=initializers.Constant(temperature),
trainable=False,
)

@@ -145,6 +152,21 @@ def __init__(
else None,
)

def get_config(self):
config = super().get_config()
config.update(
{
"depth": self.depth,
"num_features": self.num_features,
"used_features_rate": self.used_features_rate,
"output_units": self.output_units,
"l2_decision": self.l2_decision,
"l2_leaf": self.l2_leaf,
"temperature": self._init_temperature,
}
)
return config

def call(self, features):
batch_size = K.shape(features)[0]

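
Storing the constructor arguments (and the initial temperature rather than the annealed weight value) is what makes get_config sufficient for Keras to rebuild the tree. A minimal sketch of that round-trip, not part of this diff; the keyword names are inferred from the get_config keys above and may not match the exact constructor signature:

import keras
from centimators.model_estimators.keras_estimators.tree import NeuralDecisionTree

# Keyword names inferred from get_config above; treat them as assumptions.
tree = NeuralDecisionTree(
    depth=3,
    num_features=8,
    used_features_rate=0.5,
    output_units=1,
)
tree(keras.ops.zeros((2, 8)))  # one forward pass builds the weights

config = tree.get_config()
# config["temperature"] holds the initial value stored in __init__,
# not whatever TemperatureAnnealing has assigned during training.
clone = NeuralDecisionTree.from_config(config)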