-
Notifications
You must be signed in to change notification settings - Fork 38
Open
Labels
questionFurther information is requestedFurther information is requested
Description
I am trying to create a custom datafit for Tweedie GLM. I followed the tutorial for Poisson but it errors out with a not so useful error statement. See below the implementation and error.
from numba import jit
import numpy as np
from numpy.linalg import norm
from numba import njit
from numba import float64, int64, bool_
from skglm.datafits.base import BaseDatafit
from skglm.utils.sparse_ops import spectral_norm, _sparse_xj_dot
class Tweedie(BaseDatafit):
r"""Tweedie datafit.
The datafit reads:
.. math::
1 / n_"samples" \sum_(i=1)^(n_"samples") \left(\frac{y_i^{2-p}}{(1-p)(2-p)} - \frac{y_i e^{(1-p) Xw_i}}{1-p} + \frac{e^{(2-p) Xw_i}}{2-p}\right)
"""
def __init__(self, p=1.5):
self.p = p
def get_spec(self):
pass
def params_to_dict(self):
return dict(p=self.p)
def initialize(self, X, y):
if np.any(y < 0):
raise ValueError(
"Target vector `y` should only take non-negative values "
"when fitting a Tweedie model.")
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
if np.any(y < 0):
raise ValueError(
"Target vector `y` should only take non-negative values "
"when fitting a Tweedie model.")
@jit
def raw_grad(self, y, Xw):
"""Compute gradient of datafit w.r.t ``Xw``."""
p = self.p
return ((np.exp((1 - p) * Xw) - y * np.exp((2 - p) * Xw)) / len(y))
@jit
def raw_hessian(self, y, Xw):
"""Compute Hessian of datafit w.r.t ``Xw``."""
p = self.p
return ((1 - p) * np.exp((1 - p) * Xw) + (2 - p) * y * np.exp((2 - p) * Xw)) / len(y)
@jit
def value(self, y, w, Xw):
p = self.p
term1 = y**(2-p) / ((1-p)*(2-p))
term2 = y * np.exp((1-p) * Xw) / (1-p)
term3 = np.exp((2-p) * Xw) / (2-p)
return (np.sum(term1 - term2 + term3)) / len(y)
@jit
def gradient(self, X, y, Xw):
return X.T @ self.raw_grad(y, Xw)
@jit
def gradient_scalar(self, X, y, w, Xw, j):
return (X[:, j] @ self.raw_grad(y, Xw)) / len(y)
@jit
def full_grad_sparse(self, X_data, X_indptr, X_indices, y, Xw):
n_features = X_indptr.shape[0] - 1
grad = np.zeros(n_features, dtype=X_data.dtype)
for j in range(n_features):
grad[j] = 0.
for i in range(X_indptr[j], X_indptr[j + 1]):
grad[j] += X_data[i] * self.raw_grad(y, Xw)[X_indices[i]]
return grad
@jit
def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
grad = 0.
for i in range(X_indptr[j], X_indptr[j + 1]):
idx_i = X_indices[i]
grad += X_data[i] * self.raw_grad(y, Xw)[idx_i]
return grad / len(y)
@jit
def intercept_update_step(self, y, Xw):
return np.sum(self.raw_grad(y, Xw))
from skglm.solvers import GramCD
from skglm.penalties import L1_plus_L2
from skglm.estimators import GeneralizedLinearEstimator as SKGLM
skmod = SKGLM(
datafit=Tweedie(p=1.5),
penalty=L1_plus_L2(alpha=1.0, l1_ratio=1.0),
solver=GramCD(max_iter=1000, fit_intercept=True)
)
skmod.fit(X, y)
I get the error message below:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[63], line 7
1 skmod = SKGLM(
2 datafit=Tweedie(p=1.5),
3 penalty=L1_plus_L2(alpha=1.0, l1_ratio=1.0),
4 solver=GramCD(max_iter=1000, fit_intercept=True)
5 )
----> 7 skmod.fit(X_trans, y)
File /blah/blah/miniconda/envs/ipykernel_python3.9/lib/python3.9/site-packages/skglm/estimators.py:252, in GeneralizedLinearEstimator.fit(self, X, y)
249 self.datafit = self.datafit if self.datafit else Quadratic()
250 self.solver = self.solver if self.solver else AndersonCD()
--> 252 return _glm_fit(X, y, self, self.datafit, self.penalty, self.solver)
File /blah/blah/miniconda/envs/ipykernel_python3.9/lib/python3.9/site-packages/skglm/estimators.py:104, in _glm_fit(X, y, model, datafit, penalty, solver)
101 n_samples, n_features = X_.shape
103 penalty_jit = compiled_clone(penalty)
--> 104 datafit_jit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
105 if issparse(X):
106 datafit_jit.initialize_sparse(X_.data, X_.indptr, X_.indices, y)
File /blah/blah/miniconda/envs/ipykernel_python3.9/lib/python3.9/site-packages/skglm/utils/jit_compilation.py:79, in compiled_clone(instance, to_float32)
63 def compiled_clone(instance, to_float32=False):
64 """Compile instance to a jitclass.
65
66 Parameters
(...)
77 Return a jitclass.
78 """
---> 79 return jit_cached_compile(
80 instance.__class__,
81 instance.get_spec(),
82 to_float32,
83 )(**instance.params_to_dict())
File /blah/blah/miniconda/envs/ipykernel_python3.9/lib/python3.9/site-packages/skglm/utils/jit_compilation.py:60, in jit_cached_compile(klass, spec, to_float32)
57 if to_float32:
58 spec = spec_to_float32(spec)
---> 60 return jitclass(spec)(klass)
File /blah/blah/miniconda/envs/ipykernel_python3.9/lib/python3.9/site-packages/numba/experimental/jitclass/decorators.py:77, in jitclass.<locals>.wrap(cls)
74 else:
75 from numba.experimental.jitclass.base import (register_class_type,
76 ClassBuilder)
---> 77 cls_jitted = register_class_type(cls, spec, types.ClassType,
78 ClassBuilder)
80 # Preserve the module name of the original class
81 cls_jitted.__module__ = cls.__module__
File /blah/blah/miniconda/envs/ipykernel_python3.9/lib/python3.9/site-packages/numba/experimental/jitclass/base.py:213, in register_class_type(cls, spec, class_ctor, builder)
211 msg = "class members are not yet supported: {0}"
212 members = ', '.join(others.keys())
--> 213 raise TypeError(msg.format(members))
215 for k, v in props.items():
216 if v.fdel is not None:
TypeError: class members are not yet supported: value, raw_grad, raw_hessian, gradient, gradient_scalar, full_grad_sparse, gradient_scalar_sparse, intercept_update_step
PABannier
Metadata
Metadata
Assignees
Labels
questionFurther information is requestedFurther information is requested