Skip to content

Commit 1b970a5

Browse files
committed
no more need for FeaturePredecessor with empty (default) value
1 parent 3a55c3f commit 1b970a5

File tree

9 files changed

+145
-8105
lines changed

9 files changed

+145
-8105
lines changed

notebooks/scratch_features.ipynb

+57-63
Large diffs are not rendered by default.

notebooks/scratch_features2.ipynb

+77-8,000
Large diffs are not rendered by default.

src/eegdash/features/extractors.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def __init__(
6565
def _validate_execution_tree(self, feature_extractors):
6666
for fname, f in feature_extractors.items():
6767
f = _get_underlying_func(f)
68-
assert type(self) in f.parent_extractor_type
68+
pe_type = getattr(f, "parent_extractor_type", [FeatureExtractor])
69+
assert type(self) in pe_type
6970
return feature_extractors
7071

7172
def _check_is_fitable(self, feature_extractors):

src/eegdash/features/feature_bank/complexity.py

-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def _channel_app_samp_entropy_counts(x, m, r, l):
2929
return kdtree.query_radius(x_emb, r, count_only=True)
3030

3131

32-
@FeaturePredecessor()
3332
class EntropyFeatureExtractor(FeatureExtractor):
3433
def preprocess(self, x, m=2, r=0.2, l=1):
3534
rr = r * x.std(axis=-1)
@@ -57,7 +56,6 @@ def complexity_sample_entropy(counts_m, counts_mp1):
5756
return -np.log(A / B)
5857

5958

60-
@FeaturePredecessor()
6159
@univariate_feature
6260
def complexity_svd_entropy(x, m=10, tau=1):
6361
x_emb = np.empty((*x.shape[:-1], (x.shape[-1] - m + 1) // tau, m))
@@ -68,7 +66,6 @@ def complexity_svd_entropy(x, m=10, tau=1):
6866
return -np.sum(s * np.log(s), axis=-1)
6967

7068

71-
@FeaturePredecessor()
7269
@univariate_feature
7370
@nb.njit(cache=True, fastmath=True)
7471
def complexity_lempel_ziv(x, threshold=None):

src/eegdash/features/feature_bank/connectivity.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from itertools import chain
22
import numpy as np
3-
from scipy.signal import csd, coherence
3+
from scipy.signal import csd
44

55
from ..extractors import FeatureExtractor, BivariateFeature
66
from ..decorators import FeaturePredecessor, bivariate_feature
@@ -14,7 +14,6 @@
1414
]
1515

1616

17-
@FeaturePredecessor()
1817
class CoherenceFeatureExtractor(FeatureExtractor):
1918
def preprocess(self, x, **kwargs):
2019
f_min = kwargs.pop("f_min") if "f_min" in kwargs else None

src/eegdash/features/feature_bank/csp.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import scipy.linalg
55

66
from ..extractors import FitableFeature
7-
from ..decorators import FeaturePredecessor, multivariate_feature
7+
from ..decorators import multivariate_feature
88

99

1010
__all__ = [
@@ -22,7 +22,6 @@ def _update_mean_cov(count, mean, cov, x_count, x_mean, x_cov):
2222
cov[:] -= np.outer(mean, mean)
2323

2424

25-
@FeaturePredecessor()
2625
@multivariate_feature
2726
class CommonSpatialPattern(FitableFeature):
2827
def __init__(self):
@@ -77,6 +76,7 @@ def fit(self):
7776
for l in range(len(self._labels)):
7877
self._covs[l] *= self._counts[l] / (self._counts[1] - 1)
7978
l, w = scipy.linalg.eig(self._covs[0], self._covs[0] + self._covs[1])
79+
l = l.real
8080
ind = l > 0
8181
l, w = l[ind], w[:, ind]
8282
ord = np.abs(l - 0.5).argsort()[::-1]

src/eegdash/features/feature_bank/dimensionality.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numba as nb
33
from scipy import special
44

5-
from ..decorators import FeaturePredecessor, univariate_feature
5+
from ..decorators import univariate_feature
66
from .signal import signal_zero_crossings
77

88

@@ -15,7 +15,6 @@
1515
]
1616

1717

18-
@FeaturePredecessor()
1918
@univariate_feature
2019
@nb.njit(cache=True, fastmath=True)
2120
def dimensionality_higuchi_fractal_dim(x, k_max=10, eps=1e-7):
@@ -34,15 +33,13 @@ def dimensionality_higuchi_fractal_dim(x, k_max=10, eps=1e-7):
3433
return hfd
3534

3635

37-
@FeaturePredecessor()
3836
@univariate_feature
3937
def dimensionality_petrosian_fractal_dim(x):
4038
nd = signal_zero_crossings(np.diff(x, axis=-1))
4139
log_n = np.log(x.shape[-1])
4240
return log_n / (np.log(nd) + log_n)
4341

4442

45-
@FeaturePredecessor()
4643
@univariate_feature
4744
def dimensionality_katz_fractal_dim(x):
4845
dists = np.abs(np.diff(x, axis=-1))
@@ -53,7 +50,6 @@ def dimensionality_katz_fractal_dim(x):
5350
return log_n / (np.log(d / L) + log_n)
5451

5552

56-
@FeaturePredecessor()
5753
@univariate_feature
5854
@nb.njit(cache=True, fastmath=True)
5955
def _hurst_exp(x, ns, a, gamma_ratios, log_n):
@@ -80,7 +76,6 @@ def _hurst_exp(x, ns, a, gamma_ratios, log_n):
8076
return h
8177

8278

83-
@FeaturePredecessor()
8479
@univariate_feature
8580
def dimensionality_hurst_exp(x):
8681
ns = np.unique(np.power(2, np.arange(2, np.log2(x.shape[-1]) - 1)).astype(int))
@@ -94,7 +89,6 @@ def dimensionality_hurst_exp(x):
9489
return _hurst_exp(x, ns, a, gamma_ratios, log_n)
9590

9691

97-
@FeaturePredecessor()
9892
@univariate_feature
9993
@nb.njit(cache=True, fastmath=True)
10094
def dimensionality_detrended_fluctuation_analysis(x):

src/eegdash/features/feature_bank/signal.py

+5-26
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
import numpy as np
33
from scipy import stats
44

5-
from ..extractors import FeatureExtractor
6-
from ..decorators import FeaturePredecessor, univariate_feature
5+
from ..decorators import univariate_feature
76

87

98
__all__ = [
10-
"HjorthFeatureExtractor",
119
"signal_mean",
1210
"signal_variance",
1311
"signal_skewness",
@@ -25,61 +23,51 @@
2523
]
2624

2725

28-
@FeaturePredecessor()
2926
@univariate_feature
3027
def signal_mean(x):
3128
return x.mean(axis=-1)
3229

3330

34-
@FeaturePredecessor()
3531
@univariate_feature
3632
def signal_variance(x, **kwargs):
3733
return x.var(axis=-1, **kwargs)
3834

3935

40-
@FeaturePredecessor()
4136
@univariate_feature
4237
def signal_std(x, **kwargs):
4338
return x.std(axis=-1, **kwargs)
4439

4540

46-
@FeaturePredecessor()
4741
@univariate_feature
4842
def signal_skewness(x, **kwargs):
4943
return stats.skew(x, axis=x.ndim - 1, **kwargs)
5044

5145

52-
@FeaturePredecessor()
5346
@univariate_feature
5447
def signal_kurtosis(x, **kwargs):
5548
return stats.kurtosis(x, axis=x.ndim - 1, **kwargs)
5649

5750

58-
@FeaturePredecessor()
5951
@univariate_feature
6052
def signal_root_mean_square(x):
6153
return np.sqrt(np.power(x, 2).mean(axis=-1))
6254

6355

64-
@FeaturePredecessor()
6556
@univariate_feature
6657
def signal_peak_to_peak(x, **kwargs):
6758
return np.ptp(x, axis=-1, **kwargs)
6859

6960

70-
@FeaturePredecessor()
7161
@univariate_feature
7262
def signal_quantile(x, q: numbers.Number = 0.5, **kwargs):
7363
return np.quantile(x, q=q, axis=-1, **kwargs)
7464

7565

76-
@FeaturePredecessor()
7766
@univariate_feature
7867
def signal_line_length(x):
7968
return np.abs(np.diff(x, axis=-1)).mean(axis=-1)
8069

8170

82-
@FeaturePredecessor()
8371
@univariate_feature
8472
def signal_zero_crossings(x, threshold=1e-15):
8573
zero_ind = np.logical_and(x > -threshold, x < threshold)
@@ -90,25 +78,16 @@ def signal_zero_crossings(x, threshold=1e-15):
9078
return zero_cross
9179

9280

93-
@FeaturePredecessor()
94-
class HjorthFeatureExtractor(FeatureExtractor):
95-
def preprocess(self, x):
96-
return (x, np.diff(x, axis=-1), x.std(axis=-1))
97-
98-
99-
@FeaturePredecessor(HjorthFeatureExtractor)
10081
@univariate_feature
101-
def signal_hjorth_mobility(x, dx, x_std):
102-
return dx.std(axis=-1) / x_std
82+
def signal_hjorth_mobility(x):
83+
return np.diff(x, axis=-1).std(axis=-1) / x.std(axis=-1)
10384

10485

105-
@FeaturePredecessor(HjorthFeatureExtractor)
10686
@univariate_feature
107-
def signal_hjorth_complexity(x, dx, x_std):
108-
return np.diff(dx, axis=-1).std(axis=-1) / x_std
87+
def signal_hjorth_complexity(x):
88+
return np.diff(x, 2, axis=-1).std(axis=-1) / x.std(axis=-1)
10989

11090

111-
@FeaturePredecessor()
11291
@univariate_feature
11392
def signal_decorrelation_time(x, fs=1):
11493
f = np.fft.fft(x - x.mean(axis=-1, keepdims=True), axis=-1)

src/eegdash/features/feature_bank/spectral.py

-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
]
2323

2424

25-
@FeaturePredecessor()
2625
class SpectralFeatureExtractor(FeatureExtractor):
2726
def preprocess(self, x, **kwargs):
2827
f_min = kwargs.pop("f_min") if "f_min" in kwargs else None

0 commit comments

Comments
 (0)