Skip to content

Commit d332476

Browse files
committed
Refactor AdaBoost code
* Refactor comments to be more compliant with PEP 8
* Remove inconsistent / legacy / trivial comments
* Make `_plot` a static method of AdaBoostClassifier
* Fix unused `title` parameter in plot_boundary
* Rewrite accuracy as a one-liner in main
* Remove `cmap` global variable in `utils.py`
* Refactor variables to improve code readability
* Add docs and comments
* Make blank lines homogeneous
1 parent 594c87d commit d332476

File tree

3 files changed

+109
-118
lines changed

3 files changed

+109
-118
lines changed

lab/boosting/code/boosting.py

+63-76
Original file line number | Diff line number | Diff line change
@@ -1,40 +1,36 @@
1+
import warnings
2+
13
import matplotlib.pyplot as plt
24
import numpy as np
3-
4-
from utils import cmap
5+
from numpy.random import choice
56

67

78
class WeakClassifier:
89
"""
9-
Function that models a WeakClassifier
10+
Class that models a WeakClassifier
1011
"""
11-
1212
def __init__(self):
13-
14-
# initialize a few stuff
1513
self._dim = None
1614
self._threshold = None
1715
self._label_above_split = None
1816

1917
def fit(self, X: np.ndarray, Y: np.ndarray):
2018

21-
n, d = X.shape
22-
possible_labels = np.unique(Y)
23-
24-
# select random feature (see np.random.choice)
25-
self._dim = np.random.choice(a=range(0, d))
19+
# Select random feature (see np.random.choice)
20+
_, n_feats = X.shape
21+
self._dim = choice(a=range(0, n_feats))
2622

27-
# select random split (see np.random.uniform)
28-
M, m = np.max(X[:, self._dim]), np.min(X[:, self._dim])
29-
self._threshold = np.random.uniform(low=m, high=M)
23+
# Select random split threshold
24+
feat_min = np.min(X[:, self._dim])
25+
feat_max = np.max(X[:, self._dim])
26+
self._threshold = np.random.uniform(low=feat_min, high=feat_max)
3027

31-
# select random verse (see np.random.choice)
32-
self._label_above_split = np.random.choice(a=possible_labels)
28+
# Select random verse
29+
possible_labels = np.unique(Y)
30+
self._label_above_split = choice(a=possible_labels)
3331

3432
def predict(self, X: np.ndarray):
35-
36-
num_samples = X.shape[0]
37-
y_pred = np.zeros(shape=num_samples)
33+
y_pred = np.zeros(shape=X.shape[0])
3834
y_pred[X[:, self._dim] >= self._threshold] = self._label_above_split
3935
y_pred[X[:, self._dim] < self._threshold] = -1 * self._label_above_split
4036

@@ -43,20 +39,17 @@ def predict(self, X: np.ndarray):
4339

4440
class AdaBoostClassifier:
4541
"""
46-
Function that models a Adaboost classifier
42+
Class encapsulating AdaBoost classifier
4743
"""
48-
4944
def __init__(self, n_learners: int, n_max_trials: int = 200):
5045
"""
51-
Model constructor
46+
Initialize an AdaBoost classifier.
5247
5348
Parameters
5449
----------
5550
n_learners: int
56-
number of weak classifiers.
51+
Number of weak classifiers.
5752
"""
58-
59-
# initialize a few stuff
6053
self.n_learners = n_learners
6154
self.learners = []
6255
self.alphas = np.zeros(shape=n_learners)
@@ -69,74 +62,68 @@ def fit(self, X: np.ndarray, Y: np.ndarray, verbose: bool = False):
6962
7063
Parameters
7164
----------
72-
X: ndarray
73-
features having shape (n_samples, dim).
74-
Y: ndarray
75-
class labels having shape (n_samples,).
65+
X: np.ndarray
66+
Features having shape (n_samples, dim).
67+
Y: np.ndarray
68+
Class labels having shape (n_samples,).
7669
verbose: bool
77-
whether or not to visualize the learning process.
78-
Default is False
70+
Whether or not to visualize the learning process (default=False).
7971
"""
8072

81-
# some inits
82-
n, d = X.shape
83-
if d != 2:
84-
verbose = False # only plot learning if 2 dimensional
73+
n_examples, n_feats = X.shape
8574

86-
possible_labels = np.unique(Y)
75+
distinct_labels = len(np.unique(Y))
76+
if distinct_labels == 1:
77+
warnings.warn('Fitting {} on a dataset with only one label.'.format(
78+
self.__class__.__name__))
79+
elif distinct_labels > 2:
80+
raise NotImplementedError('Only binary classification is supported.')
8781

88-
# only binary problems please
89-
assert possible_labels.size == 2, 'Error: data is not binary'
82+
# Initialize all examples with equal weights
83+
weights = np.ones(shape=n_examples) / n_examples
9084

91-
# initialize the sample weights as equally probable
92-
sample_weights = np.ones(shape=n) / n
93-
94-
# start training
85+
# Train ensemble
9586
for l in range(self.n_learners):
96-
97-
# choose the indexes of 'difficult' samples (np.random.choice)
98-
cur_idx = np.random.choice(a=range(0, n), size=n, replace=True, p=sample_weights)
99-
100-
# extract 'difficult' samples
101-
cur_X = X[cur_idx]
102-
cur_Y = Y[cur_idx]
103-
104-
# search for a weak classifier
105-
error = 1
87+
# Perform a weighted re-sampling (with replacement) of the dataset
88+
# to create a new dataset on which the current weak learner will
89+
# be trained.
90+
sampled_idxs = choice(a=range(0, n_examples), size=n_examples,
91+
replace=True, p=weights)
92+
cur_X = X[sampled_idxs]
93+
cur_Y = Y[sampled_idxs]
94+
95+
# Search for a weak classifier
10696
n_trials = 0
107-
cur_wclass = None
108-
y_pred = None
109-
97+
error = 1.
11098
while error > 0.5:
99+
weak_learner = WeakClassifier()
100+
weak_learner.fit(cur_X, cur_Y)
101+
y_pred = weak_learner.predict(cur_X)
111102

112-
cur_wclass = WeakClassifier()
113-
cur_wclass.fit(cur_X, cur_Y)
114-
y_pred = cur_wclass.predict(cur_X)
115-
116-
# compute error
117-
error = np.sum(sample_weights[cur_idx[cur_Y != y_pred]])
103+
# Compute current weak learner error
104+
error = np.sum(weights[sampled_idxs[cur_Y != y_pred]])
118105

106+
# Re-initialize sample weights if number of trials is exceeded
119107
n_trials += 1
120108
if n_trials > self.n_max_trials:
121-
# initialize the sample weights again
122-
sample_weights = np.ones(shape=n) / n
109+
weights = np.ones(shape=n_examples) / n_examples
123110

124-
# save weak learner parameter
111+
# Store weak learner parameter
125112
self.alphas[l] = alpha = np.log((1 - error) / error) / 2
126113

127-
# append the weak classifier to the chain
128-
self.learners.append(cur_wclass)
114+
# Append the weak classifier to the chain
115+
self.learners.append(weak_learner)
129116

130-
# update sample weights
131-
sample_weights[cur_idx[cur_Y != y_pred]] *= np.exp(alpha)
132-
sample_weights[cur_idx[cur_Y == y_pred]] *= np.exp(-alpha)
133-
sample_weights /= np.sum(sample_weights)
117+
# Update examples weights
118+
weights[sampled_idxs[cur_Y != y_pred]] *= np.exp(alpha)
119+
weights[sampled_idxs[cur_Y == y_pred]] *= np.exp(-alpha)
120+
weights /= np.sum(weights) # re-normalize
134121

135-
if verbose:
136-
self._plot(cur_X, y_pred, sample_weights[cur_idx],
122+
# Possibly plot the predictions (if these are 2D)
123+
if verbose and n_feats == 2:
124+
self._plot(cur_X, y_pred, weights[sampled_idxs],
137125
self.learners[-1], l)
138126

139-
140127
def predict(self, X: np.ndarray):
141128
"""
142129
Function to perform predictions over a set of samples.
@@ -167,10 +154,10 @@ def predict(self, X: np.ndarray):
167154

168155
return pred
169156

170-
def _plot(self, X: np.ndarray, y_pred: np.ndarray, weights: np.ndarray,
171-
learner: WeakClassifier, iteration: int):
157+
@staticmethod
158+
def _plot(X: np.ndarray, y_pred: np.ndarray, weights: np.ndarray,
159+
learner: WeakClassifier, iteration: int, cmap: str = 'jet'):
172160

173-
# plot
174161
plt.clf()
175162
plt.scatter(X[:, 0], X[:, 1], c=y_pred, s=weights * 50000,
176163
cmap=cmap, edgecolors='k')

lab/boosting/code/main.py

+15-15
Original file line number | Diff line number | Diff line change
@@ -1,39 +1,39 @@
11
import matplotlib.pyplot as plt
22
import numpy as np
3+
4+
from boosting import AdaBoostClassifier
35
from datasets import gaussians_dataset
46
from utils import plot_2d_dataset
57
from utils import plot_boundary
68

7-
from boosting import AdaBoostClassifier
8-
9-
plt.ion()
109

1110
def main_adaboost():
1211
"""
13-
Main function for testing Adaboost.
12+
Main function for fitting and testing Adaboost classifier.
1413
"""
15-
1614
X_train, Y_train, X_test, Y_test = gaussians_dataset(2, [300, 400], [[1, 3], [-4, 8]], [[2, 3], [4, 1]])
1715
# X_train, Y_train, X_test, Y_test = h_shaped_dataset()
1816
# X_train, Y_train, X_test, Y_test = two_moon_dataset(n_samples=500, noise=0.2)
1917

20-
# visualize dataset
21-
plot_2d_dataset(X_train, Y_train, 'Training')
18+
# Visualize dataset
19+
plot_2d_dataset(X_train, Y_train, 'Training', blocking=False)
2220

23-
# train model and predict
21+
# Init model
2422
model = AdaBoostClassifier(n_learners=100)
2523

24+
# Train
2625
model.fit(X_train, Y_train, verbose=True)
27-
P = model.predict(X_test)
2826

29-
# visualize the boundary!
30-
plot_boundary(X_train, Y_train, model)
27+
# Predict
28+
y_preds = model.predict(X_test)
29+
print('Accuracy on test set: {}'.format(np.mean(y_preds == Y_test)))
3130

32-
# evaluate and print error
33-
error = float(np.sum(P == Y_test)) / Y_test.size
34-
print('Test set - Classification Accuracy: {}'.format(error))
31+
# Visualize the predicted boundary
32+
plot_boundary(X_train, Y_train, model)
3533

3634

37-
# entry point
3835
if __name__ == '__main__':
36+
37+
plt.ion()
38+
3939
main_adaboost()

lab/boosting/code/utils.py

+31-27
Original file line number | Diff line number | Diff line change
@@ -1,72 +1,76 @@
1-
import numpy as np
21
import matplotlib.pyplot as plt
3-
plt.ion()
4-
5-
cmap = 'jet'
2+
import numpy as np
63

74

8-
def plot_2d_dataset(X, Y, title=''):
5+
def plot_2d_dataset(X, Y, title='', cmap='jet', blocking: bool = False):
96
"""
107
Plots a two-dimensional dataset.
118
129
Parameters
1310
----------
14-
X: ndarray
15-
data points. (shape:(n_samples, dim))
16-
Y: ndarray
17-
groundtruth labels. (shape:(n_samples,))
11+
X: np.ndarray
12+
Data points. (shape:(n_samples, dim))
13+
Y: np.ndarray
14+
Groundtruth labels. (shape:(n_samples,))
1815
title: str
19-
an optional title for the plot.
16+
Optional title for the plot.
17+
cmap: str
18+
Colormap used for plotting
19+
blocking: bool
20+
When set, wait for user interaction
2021
"""
2122

22-
# new figure
2323
plt.figure()
2424

25-
# set lims
25+
# Compute and set range limits
2626
x_min = np.min(X[:, 0])
2727
x_max = np.max(X[:, 0])
2828
y_min = np.min(X[:, 1])
2929
y_max = np.max(X[:, 1])
3030
plt.xlim(x_min, x_max)
3131
plt.ylim(y_min, y_max)
3232

33-
# remove ticks
33+
# Remove ticks
3434
plt.xticks(())
3535
plt.yticks(())
3636

37-
# plot points
37+
# Plot points
3838
plt.scatter(X[:, 0], X[:, 1], c=Y, zorder=10, s=40, cmap=cmap, edgecolors='k')
3939
plt.title(title)
40-
plt.waitforbuttonpress()
40+
41+
if blocking:
42+
plt.waitforbuttonpress()
4143

4244

43-
def plot_boundary(X, Y, model, title=''):
45+
def plot_boundary(X, Y, model, title='', cmap='jet'):
4446
"""
4547
Represents the boundaries of a generic learning model over data.
4648
4749
Parameters
4850
----------
49-
X: ndarray
50-
data points. (shape:(n_samples, dim))
51-
Y: ndarray
52-
groundtruth labels. (shape:(n_samples,))
51+
X: np.ndarray
52+
Data points. (shape:(n_samples, dim))
53+
Y: np.ndarray
54+
Ground truth labels. (shape:(n_samples,))
5355
model: SVC
54-
A sklearn.SVC fit model.
56+
A sklearn classifier.
5557
title: str
56-
an optional title for the plot.
58+
Optional title for the plot.
59+
cmap: str
60+
Colormap used for plotting
5761
"""
5862

59-
# initialize subplots
63+
# Initialize subplots
6064
fig, ax = plt.subplots(1, 2)
6165
ax[0].scatter(X[:, 0], X[:, 1], c=Y, s=40, zorder=10, cmap=cmap, edgecolors='k')
6266

63-
# evaluate lims
67+
# Compute range limits
6468
x_min = np.min(X[:, 0])
6569
x_max = np.max(X[:, 0])
6670
y_min = np.min(X[:, 1])
6771
y_max = np.max(X[:, 1])
6872

69-
# predict all over a grid
73+
# Predict all over a dense grid
7074
XX, YY = np.mgrid[x_min:x_max:500j, y_min:y_max:500j]
7175
Z = model.predict(np.c_[XX.ravel(), YY.ravel()])
7276

@@ -75,14 +79,14 @@ def plot_boundary(X, Y, model, title=''):
7579
ax[1].pcolormesh(XX, YY, Z, cmap=plt.cm.Paired)
7680
ax[1].scatter(X[:, 0], X[:, 1], c=Y, s=40, zorder=10, cmap=cmap, edgecolors='k')
7781

78-
# set stuff for subplots
82+
# Set limits and ticks for each subplot
7983
for s in [0, 1]:
8084
ax[s].set_xlim([x_min, x_max])
8185
ax[s].set_ylim([y_min, y_max])
8286
ax[s].set_xticks([])
8387
ax[s].set_yticks([])
8488

85-
ax[0].set_title('Data')
89+
ax[0].set_title(title)
8690
ax[1].set_title('Boundary')
8791

8892
plt.waitforbuttonpress()

0 commit comments

Comments
 (0)