
Commit 81ca89d

API Updates for diffxpy (#146)
* [WIP] The initializer...does everything?
* [WIP] Refactor to fit model, accuracy test passes.
* [WIP] Remove unnecessary tests.
* [WIP] Starting Poisson model
* [WIP] Need to log likelihood
* [WIP] Beginning to handle log likelihood
* Seems right?
* Fill in functions that might be needed for testing.
* Make tests more consistent.
* Clean up model a bit.
* [WIP] Fix log likelihood?
* Remove unfinished methods
* [WIP] Tests run but some don't pass
* [WIP] Add error for scale model.
* [WIP] Add normal scale modeling
* [WIP] Fix normal model ll
* Small fixes and clean up.
* Fix Poisson log-likelihood
* Fix missing export.
* Fix init_par function
* Fix init_par for scale.
* Add docs.
* Add docs.
* Move to completed.
* Fix typing issue.
* Fix typing issue.
* Start implementing base class.
* Small changes
* [WIP] More small changes for diffxpy
* [WIP] Package can now be imported and technically run in diffxpy
* [WIP] Add feature names to generated data.
* Update API
* Add/Remove more properties
* Make signature consistent.
* Fix unused argument
* Fix categorical issue
* Line length
* Return term_names as coef_names to match matrix behavior when None
* Use columns for as_categorical.
* Use new container structure.
* Fix normal model/model_container
* Add setter functions to generate_artificial_data
* Begin fixing normal model
* Update jacobian
* Use correct linker for scale.
* Use constrained matrix for normal location model.
* Fix FIM calculation
* Fix jacobian calculation (?)
* Fix size factors calculation
* Refactor dask_compute
* Fix sparse array issue with init_par.
* Re-pin deps.
* Re-pin deps.
* Use generalized jacobian calculation.
* Add blank line
* Fix pre-commit
* Try to satisfy mypy
* Fix data tests.
* Black again.
* Update lock.
* Fix pre-commit
* Fix jacobian scale
* Black.
* Small test cleanups.
* Move ll function.
* Move ll function.
* Ok, remove!
* Make small changes.
* Add comment.
* Fix pre-commit.
* Remove bounds not needed.
* Add back in scale clipping
* Remove ll.
* Remove dask_compute.
* Fix.
* Remove dask_compute.
* Fix container.
* Fix FIM/Hessian.
* Small Poisson fixes.
* Try smaller average.
* Fix init linked.
* Formatting.
* Try different lam.
* Formatting.
* Try fixing lam issue on Windows.
* Actually fix problem.
* Still too big?
* Revert
1 parent 8022560 commit 81ca89d

39 files changed: +755 additions, -1228 deletions

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -41,11 +41,11 @@ repos:
         types: [text]
         stages: [commit, push, manual]
   - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: v2.3.0
+    rev: v2.7.1
     hooks:
       - id: prettier
   - repo: https://github.com/pycqa/isort
-    rev: 5.8.0
+    rev: 5.10.1
     hooks:
       - id: isort
         name: isort (python)

batchglm/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from . import glm_beta, glm_nb, glm_norm, glm_poisson
+from . import base_glm, glm_beta, glm_nb, glm_norm, glm_poisson
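With base_glm exported here, the base-class module is importable alongside the concrete models. A quick sanity check (the ModelGLM re-export itself is added in the base_glm/__init__.py change below):

import batchglm

# ModelGLM is the newly public base class (renamed from _ModelGLM in this commit).
from batchglm.models import base_glm, glm_beta, glm_nb, glm_norm, glm_poisson

print(base_glm.ModelGLM)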

batchglm/models/base_glm/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 # from .estimator import _EstimatorGLM
 from ...utils.input import InputDataGLM
-from .model import _ModelGLM
-from .utils import closedform_glm_mean, closedform_glm_scale, parse_design
+from .model import ModelGLM
+from .utils import closedform_glm_mean, closedform_glm_scale

batchglm/models/base_glm/model.py

Lines changed: 38 additions & 8 deletions
@@ -1,19 +1,22 @@
 import abc
 import logging
+import random
+import string
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import dask.array
 import numpy as np
+import pandas as pd
 import scipy
 
 from ...utils.input import InputDataGLM
 from .external import pkg_constants
-from .utils import generate_sample_description, parse_constraints, parse_design
+from .utils import generate_sample_description
 
 logger = logging.getLogger(__name__)
 
 
-class _ModelGLM(metaclass=abc.ABCMeta):
+class ModelGLM(metaclass=abc.ABCMeta):
     """
     Generalized Linear Model base class.
 
@@ -43,13 +46,15 @@ class _ModelGLM(metaclass=abc.ABCMeta):
     _cast_dtype: str = "float32"
     _chunk_size_cells: int
     _chunk_size_genes: int
+    _sample_description: pd.DataFrame
+    _features: List[str]
 
     def __init__(
         self,
         input_data: Optional[InputDataGLM] = None,
     ):
         """
-        Create a new _ModelGLM object.
+        Create a new ModelGLM object.
 
         :param input_data: Input data for the model
 
@@ -72,9 +77,14 @@ def extract_input_data(self, input_data: InputDataGLM):
         self._cast_dtype = input_data.cast_dtype
         self._chunk_size_genes = input_data.chunk_size_genes
         self._chunk_size_cells = input_data.chunk_size_cells
+        self._features = input_data.features
         self._xh_loc = np.matmul(self.design_loc, self.constraints_loc)
         self._xh_scale = np.matmul(self.design_scale, self.constraints_scale)
 
+    @property
+    def features(self) -> List[str]:
+        return self._features
+
     @property
     def chunk_size_cells(self) -> int:
         return self._chunk_size_cells
@@ -87,6 +97,10 @@ def chunk_size_genes(self) -> int:
     def cast_dtype(self) -> str:
         return self._cast_dtype
 
+    @property
+    def sample_description(self) -> pd.DataFrame:
+        return self._sample_description
+
     @property
     def design_loc(self) -> Union[np.ndarray, dask.array.core.Array]:
         """location design matrix"""
@@ -356,7 +370,7 @@ def generate_params(
         if rand_fn_scale is None:
             rand_fn_scale = rand_fn
 
-        _design_loc, _design_scale, _ = generate_sample_description(**kwargs)
+        _design_loc, _design_scale, _sample_description = generate_sample_description(**kwargs)
 
         self._theta_location = np.concatenate(
             [
@@ -366,8 +380,9 @@
             axis=0,
         )
         self._theta_scale = np.concatenate([rand_fn_scale((_design_scale.shape[1], n_vars))], axis=0)
+        self._sample_description = _sample_description
 
-        return _design_loc, _design_scale
+        return _design_loc, _design_scale, _sample_description
 
     def generate_artificial_data(
         self,
@@ -379,6 +394,8 @@
         shuffle_assignments: bool = False,
         sparse: bool = False,
         as_dask: bool = True,
+        theta_location_setter: Optional[Callable] = None,
+        theta_scale_setter: Optional[Callable] = None,
         **kwargs,
     ):
         """
@@ -391,9 +408,11 @@
         :param shuffle_assignments: Depcreated. Does not do anything.
         :param sparse: If True, the simulated data matrix is sparse.
         :param as_dask: If True, use dask.
+        :param theta_location_setter: Override for parameter after generate_params, should return the parameter
+        :param theta_scale_setter: Override for parameter after generate_params, should return the parameter
         :param kwargs: Additional kwargs passed to generate_params.
         """
-        _design_loc, _design_scale = self.generate_params(
+        _design_loc, _design_scale, _ = self.generate_params(
            n_vars=n_vars,
            num_observations=n_obs,
            num_conditions=num_conditions,
@@ -402,6 +421,10 @@
            shuffle_assignments=shuffle_assignments,
            **kwargs,
        )
+        if theta_location_setter is not None:
+            self._theta_location = theta_location_setter(self._theta_location)
+        if theta_scale_setter is not None:
+            self._theta_scale = theta_scale_setter(self._theta_scale)
 
         # we need to do this explicitly here in order to generate data
         self._constraints_loc = np.identity(n=_design_loc.shape[1])
@@ -413,8 +436,15 @@
         data_matrix = self.generate_data().astype(self.cast_dtype)
         if sparse:
             data_matrix = scipy.sparse.csr_matrix(data_matrix)
-
-        input_data = InputDataGLM(data=data_matrix, design_loc=_design_loc, design_scale=_design_scale, as_dask=as_dask)
+        # generate random gene/feature names
+        feature_names = "".join("feature_" + str(i) for i in range(n_vars))
+        input_data = InputDataGLM(
+            data=data_matrix,
+            design_loc=_design_loc,
+            design_scale=_design_scale,
+            as_dask=as_dask,
+            feature_names=feature_names,
+        )
         self.extract_input_data(input_data)
 
     @abc.abstractmethod
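The new theta_location_setter/theta_scale_setter hooks let a caller pin or transform the randomly drawn parameters before data are simulated. A minimal sketch, assuming the concrete negative-binomial Model can be constructed without input data (input_data is Optional above); the import path and the num_batches keyword are not shown in this diff and are assumptions:

import numpy as np

from batchglm.models.glm_nb import Model  # assumed import path for the concrete NB model

model = Model()

# Pin every scale coefficient to log(2) instead of keeping the random draw;
# each setter receives the drawn parameter matrix and must return a replacement.
model.generate_artificial_data(
    n_obs=200,
    n_vars=100,
    num_conditions=2,
    num_batches=4,  # assumed passthrough kwarg to generate_params
    theta_scale_setter=lambda theta: np.full_like(theta, np.log(2.0)),
)

print(model.features[:3])               # feature names now attached via InputDataGLM
print(model.sample_description.head())  # simulated design retained by generate_params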

batchglm/models/base_glm/utils.py

Lines changed: 15 additions & 85 deletions
@@ -7,12 +7,22 @@
 import pandas as pd
 import patsy
 import scipy.sparse
+import sparse
 
 from .external import groupwise_solve_lm
 
 logger = logging.getLogger("batchglm")
 
 
+def densify(arr):
+    if isinstance(arr, dask.array.core.Array):
+        arr = arr.compute()
+    if isinstance(arr, sparse.COO) or isinstance(arr, scipy.sparse.csr_matrix):
+        return arr.todense()
+    else:
+        return arr
+
+
 def generate_sample_description(
     num_observations: int,
     num_conditions: int,
@@ -61,87 +71,6 @@ def generate_sample_description(
     return sim_design_loc, sim_design_scale, sample_description
 
 
-def parse_design(
-    design_matrix: Union[pd.DataFrame, patsy.design_info.DesignMatrix, dask.array.core.Array, np.ndarray],
-    param_names: List[str] = None,
-) -> Tuple[np.ndarray, List[str]]:
-    r"""
-    Parser for design matrices.
-
-    :param design_matrix: Design matrix.
-    :param param_names:
-        Optional coefficient names for design_matrix.
-        Ignored if design_matrix is pd.DataFrame or patsy.design_info.DesignMatrix.
-    :return: Tuple[np.ndarray, List[str]] containing the design matrix and the parameter names.
-    :raise AssertionError: if the type of design_matrix is not understood.
-    :raise AssertionError: if length of provided param_names is not equal to number of coefficients in design_matrix.
-    :raise ValueError: if param_names is None when type of design_matrix is numpy.ndarray or dask.array.core.Array.
-    """
-    if isinstance(design_matrix, (pd.DataFrame, patsy.design_info.DesignMatrix)) and param_names is not None:
-        logger.warning(f"The provided param_names are ignored as the design matrix is of type {type(design_matrix)}.")
-
-    if isinstance(design_matrix, patsy.design_info.DesignMatrix):
-        dmat = np.asarray(design_matrix)
-        params = design_matrix.design_info.column_names
-    elif isinstance(design_matrix, pd.DataFrame):
-        dmat = np.asarray(design_matrix)
-        params = design_matrix.columns.tolist()
-    elif isinstance(design_matrix, dask.array.core.Array):
-        dmat = design_matrix.compute()
-        params = param_names
-    elif isinstance(design_matrix, np.ndarray):
-        dmat = design_matrix
-        params = param_names
-    else:
-        raise AssertionError(f"Datatype for design_matrix not understood: {type(design_matrix)}")
-    if params is None:
-        raise ValueError("Provide names when passing design_matrix as np.ndarray or dask.array.core.Array!")
-    assert len(params) == dmat.shape[1], (
-        "Length of provided param_names is not equal to " "number of coefficients in design_matrix."
-    )
-    return dmat, params
-
-
-def parse_constraints(
-    dmat: np.ndarray,
-    dmat_par_names: List[str],
-    constraints: Optional[Union[np.ndarray, dask.array.core.Array]] = None,
-    constraint_par_names: Optional[List[str]] = None,
-) -> Tuple[np.ndarray, List[str]]:
-    r"""
-    Parser for constraint matrices.
-
-    :param dmat: Design matrix.
-    :param constraints: Constraint matrix.
-    :param constraint_par_names: Optional coefficient names for constraints.
-    :return: Tuple[np.ndarray, List[str]] containing the constraint matrix and the parameter names.
-    :raise AssertionError: if the type of given design / contraint matrix is not np.ndarray or dask.array.core.Array.
-    """
-    assert isinstance(dmat, np.ndarray), "dmat must be provided as np.ndarray."
-    if constraints is None:
-        constraints = np.identity(n=dmat.shape[1])
-        constraint_params = dmat_par_names
-    else:
-        if isinstance(constraints, dask.array.core.Array):
-            constraints = constraints.compute()
-        assert isinstance(constraints, np.ndarray), "contraints must be np.ndarray or dask.array.core.Array."
-        # Cannot use all parameter names if constraint matrix is not identity: Make up new ones.
-        # Use variable names that can be mapped (unconstrained).
-        if constraint_par_names is not None:
-            assert len(constraint_params) == len(constraint_par_names)
-            constraint_params = constraint_par_names
-        else:
-            constraint_params = [
-                "var_" + str(i)
-                if np.sum(constraints[:, i] != 0) > 1
-                else dmat_par_names[np.where(constraints[:, i] != 0)[0][0]]
-                for i in range(constraints.shape[1])
-            ]
-        assert constraints.shape[0] == dmat.shape[1], "constraint dimension mismatch"
-
-    return constraints, constraint_params
-
-
 def closedform_glm_mean(
     x: Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array],
     dmat: Union[np.ndarray, dask.array.core.Array],
@@ -168,8 +97,9 @@
         x = np.divide(x, size_factors)
 
     def apply_fun(grouping):
+
         groupwise_means = np.asarray(
-            np.vstack([np.mean(x[np.where(grouping == g)[0], :], axis=0) for g in np.unique(grouping)])
+            np.vstack([np.mean(densify(x[np.where(grouping == g)[0], :]), axis=0) for g in np.unique(grouping)])
        )
        if link_fn is None:
            return groupwise_means
@@ -218,7 +148,7 @@ def apply_fun(grouping):
     # Calculate group-wise means if not supplied. These are required for variance and MME computation.
     if provided_groupwise_means is None:
         gw_means = np.asarray(
-            np.vstack([np.mean(x[np.where(grouping == g)[0], :], axis=0) for g in np.unique(grouping)])
+            np.vstack([np.mean(densify(x[np.where(grouping == g)[0], :]), axis=0) for g in np.unique(grouping)])
         )
     else:
         gw_means = provided_groupwise_means
@@ -228,14 +158,14 @@
         expect_xsq = np.asarray(
             np.vstack(
                 [
-                    np.asarray(np.mean(x[np.where(grouping == g)[0], :].power(2), axis=0))
+                    np.asarray(np.mean(densify(x[np.where(grouping == g)[0], :]).power(2), axis=0))
                     for g in np.unique(grouping)
                ]
            )
        )
    else:
        expect_xsq = np.vstack(
-            [np.mean(np.square(x[np.where(grouping == g)[0], :]), axis=0) for g in np.unique(grouping)]
+            [np.mean(np.square(densify(x[np.where(grouping == g)[0], :])), axis=0) for g in np.unique(grouping)]
        )
        expect_x_sq = np.square(gw_means)
        variance = expect_xsq - expect_x_sq
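The new densify helper is what lets the group-wise means above be computed uniformly over dask-, pydata-sparse-, and scipy-CSR-backed matrices. A standalone sketch of the same logic, re-implemented here so the example runs without batchglm:

import dask.array
import numpy as np
import scipy.sparse
import sparse


def densify(arr):
    # Materialize lazy dask arrays first; the computed result may itself be sparse.
    if isinstance(arr, dask.array.core.Array):
        arr = arr.compute()
    # Convert either sparse backend to a dense matrix/ndarray.
    if isinstance(arr, (sparse.COO, scipy.sparse.csr_matrix)):
        return arr.todense()
    return arr


print(densify(scipy.sparse.csr_matrix(np.eye(3))))  # dense 3x3 identity
print(densify(dask.array.ones((2, 2))))             # plain 2x2 ndarray of ones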

batchglm/models/glm_beta/external.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 import batchglm.utils.data as data_utils
 from batchglm import pkg_constants
-from batchglm.models.base_glm import _ModelGLM, closedform_glm_mean, closedform_glm_scale
+from batchglm.models.base_glm import ModelGLM, closedform_glm_mean, closedform_glm_scale
 from batchglm.utils.linalg import groupwise_solve_lm

batchglm/models/glm_beta/model.py

Lines changed: 2 additions & 2 deletions
@@ -4,10 +4,10 @@
 import dask
 import numpy as np
 
-from .external import _ModelGLM
+from .external import ModelGLM
 
 
-class Model(_ModelGLM, metaclass=abc.ABCMeta):
+class Model(ModelGLM, metaclass=abc.ABCMeta):
     """
     Generalized Linear Model (GLM) with beta distributed noise, logit link for location and log link for scale.
     """

batchglm/models/glm_nb/external.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 import batchglm.utils.data as data_utils
 from batchglm import pkg_constants
-from batchglm.models.base_glm import _ModelGLM, closedform_glm_mean, closedform_glm_scale
+from batchglm.models.base_glm import ModelGLM, closedform_glm_mean, closedform_glm_scale
 from batchglm.utils.linalg import groupwise_solve_lm

batchglm/models/glm_nb/model.py

Lines changed: 2 additions & 2 deletions
@@ -4,10 +4,10 @@
 import dask.array
 import numpy as np
 
-from .external import _ModelGLM
+from .external import ModelGLM
 
 
-class Model(_ModelGLM, metaclass=abc.ABCMeta):
+class Model(ModelGLM, metaclass=abc.ABCMeta):
     """
     Generalized Linear Model (GLM) with negative binomial noise.
     """

batchglm/models/glm_norm/external.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 import batchglm.utils.data as data_utils
 from batchglm import pkg_constants
-from batchglm.models.base_glm import _ModelGLM, closedform_glm_mean, closedform_glm_scale
+from batchglm.models.base_glm import ModelGLM, closedform_glm_mean, closedform_glm_scale
 from batchglm.utils.linalg import groupwise_solve_lm

batchglm/models/glm_norm/model.py

Lines changed: 2 additions & 2 deletions
@@ -4,10 +4,10 @@
 import dask
 import numpy as np
 
-from .external import _ModelGLM
+from .external import ModelGLM
 
 
-class Model(_ModelGLM, metaclass=abc.ABCMeta):
+class Model(ModelGLM, metaclass=abc.ABCMeta):
 
     """Generalized Linear Model (GLM) with normal noise."""
 

batchglm/models/glm_norm/utils.py

Lines changed: 1 addition & 2 deletions
@@ -61,7 +61,6 @@ def init_par(model, init_location: str, init_scale: str) -> Tuple[np.ndarray, np
     &= D \cdot x' = f^{-1}(\theta)
     $$
     """
-
     groupwise_means = None
 
     init_location_str = init_location.lower()
@@ -79,7 +78,7 @@ def init_par(model, init_location: str, init_scale: str) -> Tuple[np.ndarray, np
     elif init_location_str == "standard":
         overall_means = np.mean(model.x, axis=0)  # directly calculate the mean
         init_theta_location = np.zeros([model.num_loc_params, model.num_features])
-        init_theta_location[0, :] = np.log(overall_means)
+        init_theta_location[0, :] = overall_means  # identity linked.
     else:
         raise ValueError("init_location string %s not recognized" % init_location)
 
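The one-line fix above matters because the normal model's location uses an identity link: the intercept row of theta should hold the raw per-feature mean, not its log (which is correct only for log-linked models such as NB). A small standalone check of the corrected initialization; variable names mirror the diff but the snippet is illustrative:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=1.0, size=(100, 3))  # toy data: 100 observations x 3 features

num_loc_params, num_features = 2, x.shape[1]
overall_means = np.mean(x, axis=0)

init_theta_location = np.zeros([num_loc_params, num_features])
# Identity link: the linear predictor equals the mean itself, so the intercept
# is initialized with the raw mean; np.log(overall_means) would only be right
# under a log link.
init_theta_location[0, :] = overall_means
print(init_theta_location[0])  # approximately [5. 5. 5.]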

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 import batchglm.utils.data as data_utils
 from batchglm import pkg_constants
-from batchglm.models.base_glm import _ModelGLM, closedform_glm_mean, closedform_glm_scale
+from batchglm.models.base_glm import ModelGLM, closedform_glm_mean, closedform_glm_scale
 from batchglm.utils.linalg import groupwise_solve_lm

0 commit comments