Skip to content

Commit 9f0886f

Browse files
committed
Merge branch 'ig/full_branch' into ig/update_api
2 parents 6d3c0e4 + 8f07947 commit 9f0886f

File tree

16 files changed

+485
-11
lines changed

16 files changed

+485
-11
lines changed

batchglm/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from . import base_glm, glm_beta, glm_nb, glm_norm
1+
from . import base_glm, glm_beta, glm_nb, glm_norm, glm_poisson

batchglm/models/base_glm/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
input_data: Optional[InputDataGLM] = None,
5555
):
5656
"""
57-
Create a new _ModelGLM object.
57+
Create a new ModelGLM object.
5858
5959
:param input_data: Input data for the model
6060
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .model import Model
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import batchglm.utils.data as data_utils
2+
from batchglm import pkg_constants
3+
from batchglm.models.base_glm import ModelGLM, closedform_glm_mean, closedform_glm_scale
4+
from batchglm.utils.linalg import groupwise_solve_lm

batchglm/models/glm_poisson/model.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import abc
2+
from typing import Any, Callable, Dict, Optional, Tuple, Union
3+
4+
import dask.array
5+
import numpy as np
6+
7+
from .external import ModelGLM
8+
9+
10+
class Model(ModelGLM, metaclass=abc.ABCMeta):
11+
"""
12+
Generalized Linear Model (GLM) with Poisson noise.
13+
"""
14+
15+
def link_loc(self, data) -> Union[np.ndarray, dask.array.core.Array]:
16+
return np.log(data)
17+
18+
def inverse_link_loc(self, data) -> Union[np.ndarray, dask.array.core.Array]:
19+
return np.exp(data)
20+
21+
def link_scale(self, data) -> Union[np.ndarray, dask.array.core.Array]:
22+
return np.log(data)
23+
24+
def inverse_link_scale(self, data) -> Union[np.ndarray, dask.array.core.Array]:
25+
return np.exp(data)
26+
27+
@property
28+
def eta_loc(self) -> Union[np.ndarray, dask.array.core.Array]:
29+
eta = np.matmul(self.design_loc, self.theta_location_constrained)
30+
if self.size_factors is not None:
31+
eta += self.size_factors
32+
eta = self.np_clip_param(eta, "eta_loc")
33+
return eta
34+
35+
def eta_loc_j(self, j) -> Union[np.ndarray, dask.array.core.Array]:
36+
# Make sure that dimensionality of sliced array is kept:
37+
if isinstance(j, int) or isinstance(j, np.int32) or isinstance(j, np.int64):
38+
j = [j]
39+
eta = np.matmul(self.design_loc, self.theta_location_constrained[:, j])
40+
if self.size_factors is not None:
41+
eta += self.size_factors
42+
eta = self.np_clip_param(eta, "eta_loc")
43+
return eta
44+
45+
# Re-parameterizations:
46+
47+
@property
48+
def lam(self) -> Union[np.ndarray, dask.array.core.Array]:
49+
return self.location
50+
51+
# param constraints:
52+
53+
def bounds(self, sf, dmax, dtype) -> Tuple[Dict[str, Any], Dict[str, Any]]:
54+
55+
bounds_min = {
56+
"theta_location": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
57+
"eta_loc": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
58+
"loc": np.nextafter(0, np.inf, dtype=dtype),
59+
"scale": np.nextafter(0, np.inf, dtype=dtype),
60+
"likelihood": dtype(0),
61+
"ll": np.log(np.nextafter(0, np.inf, dtype=dtype)),
62+
# Not used and should be removed: https://github.com/theislab/batchglm/issues/148
63+
"theta_scale": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
64+
"eta_scale": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
65+
}
66+
bounds_max = {
67+
"theta_location": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf,
68+
"eta_loc": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf,
69+
"loc": np.nextafter(dmax, -np.inf, dtype=dtype) / sf,
70+
"scale": np.nextafter(dmax, -np.inf, dtype=dtype) / sf,
71+
"likelihood": dtype(1),
72+
"ll": dtype(0),
73+
# Not used and should be removed: https://github.com/theislab/batchglm/issues/148
74+
"theta_scale": np.log(dmax) / sf,
75+
"eta_scale": np.log(dmax) / sf,
76+
77+
}
78+
return bounds_min, bounds_max
79+
80+
# simulator:
81+
82+
@property
83+
def rand_fn_ave(self) -> Optional[Callable]:
84+
return lambda shape: np.random.poisson(500, shape) + 1
85+
86+
@property
87+
def rand_fn(self) -> Optional[Callable]:
88+
return lambda shape: np.abs(np.random.uniform(0.5, 2, shape))
89+
90+
@property
91+
def rand_fn_loc(self) -> Optional[Callable]:
92+
return None
93+
94+
@property
95+
def rand_fn_scale(self) -> Optional[Callable]:
96+
return None
97+
98+
def generate_data(self) -> np.ndarray:
99+
"""
100+
Sample random data based on poisson distribution and parameters.
101+
"""
102+
return np.random.poisson(lam=self.lam)

batchglm/models/glm_poisson/utils.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import logging
2+
from typing import Callable, Optional, Tuple, Union
3+
4+
import dask
5+
import numpy as np
6+
import scipy.sparse
7+
8+
from .external import closedform_glm_mean
9+
10+
logger = logging.getLogger("batchglm")
11+
12+
13+
def closedform_poisson_glm_loglam(
14+
x: Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array],
15+
design_loc: Union[np.ndarray, dask.array.core.Array],
16+
constraints_loc: Union[np.ndarray, dask.array.core.Array],
17+
size_factors: Optional[np.ndarray] = None,
18+
link_fn: Callable = np.log,
19+
inv_link_fn: Callable = np.exp,
20+
):
21+
r"""
22+
Calculates a closed-form solution for the `lam` parameters of poisson GLMs.
23+
24+
:param x: The sample data
25+
:param design_loc: design matrix for location
26+
:param constraints_loc: tensor (all parameters x dependent parameters)
27+
Tensor that encodes how complete parameter set which includes dependent
28+
parameters arises from indepedent parameters: all = <constraints, indep>.
29+
This form of constraints is used in vector generalized linear models (VGLMs).
30+
:param size_factors: size factors for X
31+
:return: tuple: (groupwise_means, mu, rmsd)
32+
"""
33+
return closedform_glm_mean(
34+
x=x,
35+
dmat=design_loc,
36+
constraints=constraints_loc,
37+
size_factors=size_factors,
38+
link_fn=link_fn,
39+
inv_link_fn=inv_link_fn,
40+
)
41+
42+
43+
def init_par(model, init_location: str) -> Tuple[np.ndarray, np.ndarray, bool, bool]:
44+
r"""
45+
standard:
46+
Only initialise intercept and keep other coefficients as zero.
47+
48+
closed-form:
49+
Initialize with Maximum Likelihood / Maximum of Momentum estimators
50+
51+
Idea:
52+
$$
53+
\theta &= f(x) \\
54+
\Rightarrow f^{-1}(\theta) &= x \\
55+
&= (D \cdot D^{+}) \cdot x \\
56+
&= D \cdot (D^{+} \cdot x) \\
57+
&= D \cdot x' = f^{-1}(\theta)
58+
$$
59+
"""
60+
train_loc = False
61+
62+
def auto_loc(dmat: Union[np.ndarray, dask.array.core.Array]) -> str:
63+
"""
64+
Checks if dmat is one-hot encoded and returns 'closed_form' if so, else 'standard'
65+
66+
:param dmat The design matrix to check.
67+
"""
68+
unique_params = np.unique(dmat)
69+
if isinstance(unique_params, dask.array.core.Array):
70+
unique_params = unique_params.compute()
71+
if len(unique_params) == 2 and unique_params[0] == 0.0 and unique_params[1] == 1.0:
72+
return "closed_form"
73+
logger.warning(
74+
(
75+
"Cannot use 'closed_form' init for loc model: "
76+
"design_loc is not one-hot encoded. Falling back to standard initialization."
77+
)
78+
)
79+
return "standard"
80+
81+
groupwise_means = None
82+
83+
init_location_str = init_location.lower()
84+
# Chose option if auto was chosen
85+
if init_location_str == "auto":
86+
87+
init_location_str = auto_loc(model.design_loc)
88+
89+
if init_location_str == "closed_form":
90+
groupwise_means, init_theta_location, rmsd_a = closedform_poisson_glm_loglam(
91+
x=model.x,
92+
design_loc=model.design_loc,
93+
constraints_loc=model.constraints_loc,
94+
size_factors=model.size_factors,
95+
link_fn=lambda lam: np.log(lam + np.nextafter(0, 1, dtype=lam.dtype)),
96+
)
97+
# train mu, if the closed-form solution is inaccurate
98+
train_loc = not (np.all(np.abs(rmsd_a) < 1e-20) or rmsd_a.size == 0)
99+
if model.size_factors is not None:
100+
if np.any(model.size_factors != 1):
101+
train_loc = True
102+
103+
elif init_location_str == "standard":
104+
overall_means = np.mean(model.x, axis=0) # directly calculate the mean
105+
init_theta_location = np.zeros([model.num_loc_params, model.num_features])
106+
init_theta_location[0, :] = np.log(overall_means)
107+
train_loc = True
108+
elif init_location_str == "all_zero":
109+
init_theta_location = np.zeros([model.num_loc_params, model.num_features])
110+
train_loc = True
111+
else:
112+
raise ValueError("init_location string %s not recognized" % init_location)
113+
114+
# Scale is not used so just return init_theta_location for what would be init_theta_scale
115+
return init_theta_location, init_theta_location, train_loc, True

batchglm/train/numpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from . import glm_nb as nb
1+
from . import glm_nb, glm_poisson
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .estimator import Estimator
2+
from .model_container import ModelContainer
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import sys
2+
from typing import Optional, Tuple, Union
3+
4+
import numpy as np
5+
6+
from .external import EstimatorGlm, Model, init_par
7+
from .model_container import ModelContainer
8+
9+
10+
class Estimator(EstimatorGlm):
11+
"""
12+
Estimator for Generalized Linear Models (GLMs) with negative binomial noise.
13+
Uses the natural logarithm as linker function.
14+
15+
Attributes
16+
----------
17+
model_vars : ModelVars
18+
model variables
19+
"""
20+
21+
def __init__(
22+
self,
23+
model: Model,
24+
init_location: str = "AUTO",
25+
init_scale: str = "AUTO",
26+
# batch_size: Optional[Union[Tuple[int, int], int]] = None,
27+
quick_scale: bool = False,
28+
dtype: str = "float64",
29+
):
30+
"""
31+
Performs initialisation and creates a new estimator.
32+
33+
:param init_location: (Optional)
34+
Low-level initial values for a. Can be:
35+
36+
- str:
37+
* "auto": automatically choose best initialization
38+
* "standard": initialize intercept with observed mean
39+
* "init_model": initialize with another model (see `ìnit_model` parameter)
40+
* "closed_form": try to initialize with closed form
41+
- np.ndarray: direct initialization of 'a'
42+
:param dtype: Numerical precision.
43+
"""
44+
init_theta_location, _, train_loc, _ = init_par(model=model, init_location=init_location)
45+
self._train_loc = train_loc
46+
# no need to train the scale parameter for the poisson model since it only has one parameter
47+
self._train_scale = False
48+
sys.stdout.write("training location model: %s\n" % str(self._train_loc))
49+
init_theta_location = init_theta_location.astype(dtype)
50+
51+
_model_container = ModelContainer(
52+
model=model,
53+
init_theta_location=init_theta_location,
54+
init_theta_scale=init_theta_location, # Not used.
55+
chunk_size_genes=model.chunk_size_genes,
56+
dtype=dtype,
57+
)
58+
super(Estimator, self).__init__(model_container=_model_container, dtype=dtype)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
class NoScaleError(Exception):
2+
"""
3+
Exception raised for attempting to access the scale parameter (or one of its derived methods) of a poisson model.
4+
"""
5+
6+
def __init__(self, method):
7+
self.message = f"Attempted to access {method}. No scale parameter is fit for poisson - please use location."
8+
super().__init__(self.message)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import batchglm.utils.data as data_utils
2+
from batchglm import pkg_constants
3+
from batchglm.models.base_glm.utils import closedform_glm_mean, closedform_glm_scale
4+
from batchglm.models.glm_poisson.model import Model
5+
from batchglm.models.glm_poisson.utils import init_par
6+
7+
# import necessary base_glm layers
8+
from batchglm.train.numpy.base_glm import EstimatorGlm, NumpyModelContainer
9+
from batchglm.utils.linalg import groupwise_solve_lm

0 commit comments

Comments
 (0)