Skip to content

Commit 607a642

Browse files
committed
First attempt at emplemnting OTOS. Still needs work
1 parent d325cdf commit 607a642

File tree

3 files changed

+71
-7
lines changed

3 files changed

+71
-7
lines changed

pyloras/_otos.py

+67-7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import ot
21
from imblearn.over_sampling.base import BaseOverSampler
32
from imblearn.utils import Substitution
43
from imblearn.utils._docstring import (
54
_random_state_docstring,
65
_n_jobs_docstring,
76
)
7+
from sklearn.svm import LinearSVC
88
import numpy as np
99

1010
from ._common import check_random_state, safe_random_state
@@ -20,16 +20,76 @@ def __init__(
2020
self,
2121
*,
2222
sampling_strategy="auto",
23-
svm_regularization=1.0,
24-
ot_regularization=1.0,
23+
svc_reg=1.0,
24+
ot_reg=1.0,
2525
tradeoff=1.0,
2626
random_state=None,
27+
max_iter=100,
2728
):
2829
super().__init__(sampling_strategy=sampling_strategy)
29-
self.svm_regularization = svm_regularization
30-
self.ot_regularization = ot_regularization
30+
self.svc_reg = svc_reg
31+
self.ot_reg = ot_reg
3132
self.tradeoff = tradeoff
3233
self.random_state = random_state
34+
self.max_iter = max_iter
3335

34-
def fit_resample(self, X, y):
35-
return X, y
36+
def _fit_resample(self, X, y):
37+
import ot
38+
random_state = check_random_state(self.random_state)
39+
X_res = [X.copy()]
40+
y_res = [y.copy()]
41+
svc = LinearSVC(
42+
loss="hinge", C=self.svc_reg, random_state=safe_random_state(random_state)
43+
)
44+
for minority_class, samples_to_make in self.sampling_strategy_.items():
45+
if samples_to_make == 0:
46+
continue
47+
X_p = X[y == minority_class]
48+
X_n = X[y != minority_class]
49+
n_p = X_p.shape[0]
50+
n_n = X_n.shape[0]
51+
n_r = samples_to_make
52+
one_r = np.ones((n_r, 1))
53+
one_n = np.ones((n_n, 1))
54+
# set initial distribution for mu_r and mu_p
55+
mu_r = np.asarray([1.0 / n_r] * n_r)
56+
mu_p = np.asarray([1.0 / n_p] * n_p)
57+
T = mu_r[:, None] @ mu_p[:, None].T
58+
# manufactor a binary classification problem
59+
_y = np.empty_like(y)
60+
_y[y == minority_class] = 0
61+
_y[y != minority_class] = 1
62+
svc.fit(X, _y)
63+
w = svc.coef_.T
64+
65+
hingelosses = np.concatenate(
66+
[
67+
np.atleast_1d(max(1 - y_i * svc.coef_ @ x_row, 0.0))
68+
for y_i, x_row in zip(y[y == minority_class], X_p)
69+
]
70+
)
71+
mu_p = np.exp(hingelosses)
72+
mu_p /= mu_p.sum()
73+
74+
D_r = np.diag(1 / mu_r)
75+
X_r = D_r @ T @ X_p
76+
# C_p = np.apply_along_axis(c_row, axis=-1, arr=X_r)
77+
C_p = np.asarray(
78+
[
79+
[np.linalg.norm(x_row - row) for row in X_p]
80+
for x_row in X_r
81+
]
82+
)
83+
wwT = w @ w.T
84+
Theta = self.tradeoff * C_p.T - X_p @ np.kron(one_r.T, wwT @ X_n.T @ one_n + n_n * w) @ D_r
85+
Phi = X_p @ wwT @ X_p.T
86+
Psi = D_r.T @ D_r
87+
88+
print(mu_r.sum(), mu_p.sum())
89+
for _ in range(self.max_iter):
90+
transport_cost = Theta.T + n_n * Psi @ T @ Phi
91+
T = ot.sinkhorn(mu_r, mu_p, transport_cost, self.ot_reg)
92+
X_res.append(D_r @ T @ X_p)
93+
y_res.append([minority_class] * n_r)
94+
print(z)
95+
return np.concatenate(X_res), np.concatenate(y_res)

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ classifiers = [
2727
"Operating System :: MacOS :: MacOS X",
2828
]
2929

30+
[project.optional-dependencies]
31+
otos = ["POT"]
32+
3033
[project.urls]
3134
source = "https://github.com/zoj613/pyloras"
3235
tracker = "https://github.com/zoj613/pyloras/issues"

requirements-dev.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
build==0.10.0
33
imbalanced-learn==0.10.1
44
numpy==1.23.2
5+
POT==0.9.0
56
pytest==7.3.1
67
pytest-cov==4.0.0

0 commit comments

Comments
 (0)