Skip to content

Commit

Permalink
comment update
Browse files Browse the repository at this point in the history
  • Loading branch information
GardevoirX committed Feb 14, 2024
1 parent f2e5434 commit eb6ee15
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 21 deletions.
75 changes: 59 additions & 16 deletions src/skmatter/neighbors/_sparsekde.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,76 @@
pairwise_euclidean_distances,
pairwise_mahalanobis_distances,
)
from ..utils._sparsekde import *
from ..utils._sparsekde import (
NearestGridAssigner,
GaussianMixtureModel,
covariance,
local_population,
effdim,
oas,
quick_shift,
rij,
)

DIST_METRICS = {
"periodic_euclidean": pairwise_euclidean_distances,
}


class SparseKDE(BaseEstimator):
"""A sparse implementation of the Kernel Density Estimation.
The bandwidth will be optimized per sample.
- We only support Gaussian kernels. (Check
howe hard others are and make it paramater later)
- We only support Gaussian kernels.
- Implement a sklean like metric: named periodic euclidian. and make metric parameter
distance.
Parameters
----------
descriptors: Descriptors of the system where you want to build a sparse KDE.
weights: Weights of the descriptors.
kernel : {'gaussian'}, default='gaussian'
The kernel to use. Currentlty only one. Check how sklearn kernels are defined. Try to reuse
The kernel to use. Currentlty only one.
metric : str, default='periodic_euclidean'
The metric to use. Currently only one.
metric_params : dict, default=None
Additional parameters to be passed to the use of
metric. i.e. the cell dimension for `periodic_euclidean`
qs : Scaling factor used during the QS clustering.
gs : The neighbor shell for gabriel shift.
thrpcl : Clusters with a pk loewr than this value are merged with the NN.
fspread : The fractional variance for bandwidth estimation.
fpoints : The fractional number of grid points.
nmsopt : The number of mean-shift refinement steps.
Examples
--------
>>> import numpy as np
>>> from skmatter.neighbors import SparseKDE
>>> from skmatter.feature_selection import FPS
>>> np.random.seed(0)
>>> n_samples = 10000
>>> samples = np.concatenate(
>>> [np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], n_samples),
>>> np.random.multivariate_normal([4, 4], [[1, 0.5], [0.5, 0.5]], n_samples)]
>>> )
>>> selector = FPS(n_to_select=int(np.sqrt(2 * n_samples)))
>>> result = selector.fit_transform(samples.T).T
>>> estimator = SparseKDE(samples, None, fpoints=0.5, qs=0.85)
>>> estimator.fit(result)
>>> estimator.score(result)
2.7671739267690363
>>> estimator.sample()
array([[3.32383366, 3.51779084]])
"""

def __init__(
self,
descriptors: np.ndarray,
weights: np.ndarray,
kernel: str = "gaussian",
metric: str = 'periodic_euclidean',
metric: str = "periodic_euclidean",
metric_params: dict = {},
qs: float = 1.0,
gs: int = -1,
Expand All @@ -48,7 +89,7 @@ def __init__(
nmsopt: int = 0,
):
self.kernel = kernel
self.metric = metric
self.metric = DIST_METRICS[metric]
self.metric_params = metric_params
self.cell = metric_params["cell"] if "cell" in metric_params else None
self.descriptors = descriptors
Expand Down Expand Up @@ -104,10 +145,10 @@ def fit(self, X, y=None, sample_weight=None):
# else:
# sample_weight = np.ones(X.shape[0], dtype=np.float64) / X.shape[0]
self.kdecut2 = 9 * (np.sqrt(X.shape[1]) + 1) ** 2
grid_dist_mat = pairwise_euclidean_distances(X, X, squared=True, cell=self.cell)
grid_dist_mat = self.metric(X, X, squared=True, cell=self.cell)
np.fill_diagonal(grid_dist_mat, np.inf)
min_grid_dist = np.min(grid_dist_mat, axis=1)
grid_npoints, grid_neighbour, sample_labels_, sample_weight = (
_, grid_neighbour, sample_labels_, sample_weight = (
self._assign_descriptors_to_grids(X)
)
h_invs, normkernels, qscut2 = self._computes_localization(
Expand All @@ -117,10 +158,10 @@ def fit(self, X, y=None, sample_weight=None):
X, sample_weight, h_invs, normkernels, grid_neighbour
)
normpks = LSE(probs)
cluster_centers, idxroot = quick_shift(
probs, grid_dist_mat, qscut2, self.gs
cluster_centers, idxroot = quick_shift(probs, grid_dist_mat, qscut2, self.gs)
cluster_centers, idxroot = self._post_process(
X, cluster_centers, idxroot, probs, normpks
)
cluster_centers, idxroot = self._post_process(X, cluster_centers, idxroot, probs, normpks)
self.cluster_weight, self.cluster_mean, self.cluster_cov = (
self._generate_probability_model(
X,
Expand All @@ -136,9 +177,9 @@ def fit(self, X, y=None, sample_weight=None):
self.model = GaussianMixtureModel(
self.cluster_weight, self.cluster_mean, self.cluster_cov, period=self.cell
)
self.__sklearn_is_fitted__ = True
self.fitted_ = True

return self, probs
return self

def score_samples(self, X):
"""Compute the log-likelihood of each sample under the model.
Expand Down Expand Up @@ -219,7 +260,7 @@ def sample(self, n_samples=1, random_state=None):

def _assign_descriptors_to_grids(self, X):

assigner = NearestGridAssigner(self.cell)
assigner = NearestGridAssigner(self.metric, self.cell)
assigner.fit(X)
labels = assigner.predict(self.descriptors, sample_weight=self.weights)
grid_npoints = assigner.grid_npoints
Expand Down Expand Up @@ -353,7 +394,9 @@ def _computes_kernel_density_estimation(
lnk = -0.5 * (normkernel[j] + dummd1) + np.log(sample_weights[j])
prob[i] = LSE([prob[i], lnk])
else:
neighbours = neighbour[j][np.any(self.descriptors[neighbour[j]] != X[i], axis=1)]
neighbours = neighbour[j][
np.any(self.descriptors[neighbour[j]] != X[i], axis=1)
]
if neighbours.size == 0:
continue
dummd1s = pairwise_mahalanobis_distances(
Expand Down Expand Up @@ -401,7 +444,7 @@ def getidmax(v1: np.ndarray, probs: np.ndarray, clusterid: int):
for j in range(nk):
if to_merge[k]:
continue
dummd2 = pairwise_euclidean_distances(
dummd2 = self.metric(
X[idxroot[dummd1yi1]], X[idxroot[j]], cell=self.cell
)
if dummd2 < dummd1:
Expand Down
10 changes: 5 additions & 5 deletions src/skmatter/utils/_sparsekde.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,22 @@
from tqdm import tqdm

import numpy as np
from ..metrics.pairwise import pairwise_euclidean_distances


class NearestGridAssigner:
"""NearestGridAssigner Class
Assign descriptor to its nearest grid.
Assign descriptor to its nearest grid. This is an axulirary class.
Args:
cell (np.ndarray): An array of periods for each dimension of the grid.
exclude_grid (bool): Whether to exclude the grid itself from neighbor lists."""

def __init__(self, cell: Optional[np.ndarray] = None, exclude_grid: bool = True) -> None:
def __init__(self, metric, cell: Optional[np.ndarray] = None, exclude_grid: bool = True) -> None:

self.labels_ = None
self.metric = metric
self.cell = cell
self.exclude_grid = exclude_grid
self._distance = pairwise_euclidean_distances
self.grid_pos = None
self.grid_npoints = None
self.grid_weight = None
Expand Down Expand Up @@ -49,7 +48,7 @@ def predict(
for i, point in tqdm(
enumerate(X), desc="Assigning samples to grids...", total=len(X)
):
descriptor2grid = self._distance(
descriptor2grid = self.metric(
X=point.reshape(1, -1), Y=self.grid_pos, cell=self.cell
)
self.labels_.append(np.argmin(descriptor2grid))
Expand All @@ -69,6 +68,7 @@ class GaussianMixtureModel:
means: np.ndarray
covariances: np.ndarray
period: Optional[np.ndarray] = None
"""A simple class for Gaussian Mixture Model. This is an axulirary class."""

def __post_init__(self):
self.dimension = self.means.shape[1]
Expand Down

0 comments on commit eb6ee15

Please sign in to comment.