Commit 4f37707

Merge tests into master (#50)
* typo fixes
* rename tests to plots
* add test dir
* separate out plots from other tests
* consistent handling of ties during KNN classification
* consistent handling of ties during KNN classification
* add filter_punctuation flag
* move plots to a separate directory
* remove test module
* add warning for plotting and gym dependencies
* add setup.py and misc requirements for pip packaging
* fix package name and update readme
* Update README.md
* add available models to top-level readme
* Update README.md
* update installation documentation
* add best_arm to bandit oracle output
* fix github link
1 parent 7879246 commit 4f37707

38 files changed: +1525 −1186 lines

MANIFEST.in (+4)

@@ -0,0 +1,4 @@
+include README.md
+include requirements*.txt
+include docs/*.rst
+include docs/img/*.png
README.md (+170 −2)

@@ -1,9 +1,177 @@
 # numpy-ml
 Ever wish you had an inefficient but somewhat legible collection of machine
-learning algorithms implemented exclusively in numpy? No?
+learning algorithms implemented exclusively in NumPy? No?
+
+## Installation
+
+### For rapid experimentation
+To use this code as a starting point for ML prototyping / experimentation, just clone the repository, create a new [virtualenv](https://pypi.org/project/virtualenv/), and start hacking:
+
+```sh
+$ git clone https://github.com/ddbourgin/numpy-ml.git
+$ cd numpy-ml && virtualenv npml && source npml/bin/activate
+$ pip3 install -r requirements-dev.txt
+```
+
+### As a package
+If you don't plan to modify the source, you can also install numpy-ml as a
+Python package: `pip3 install -U numpy_ml`.
+
+The reinforcement learning agents train on environments defined in the [OpenAI
+gym](https://github.com/openai/gym). To install these alongside numpy-ml, you
+can use `pip3 install -U 'numpy_ml[rl]'`.
 
 ## Documentation
-To see all of the available models, take a look at the [project documentation](https://numpy-ml.readthedocs.io/) or see [here](https://github.com/ddbourgin/numpy-ml/blob/master/numpy_ml/README.md).
+For more details on the available models, see the [project documentation](https://numpy-ml.readthedocs.io/).
+
+## Available models
+1. **Gaussian mixture model**
+    - EM training
+
+2. **Hidden Markov model**
+    - Viterbi decoding
+    - Likelihood computation
+    - MLE parameter estimation via Baum-Welch/forward-backward algorithm
+
+3. **Latent Dirichlet allocation** (topic model)
+    - Standard model with MLE parameter estimation via variational EM
+    - Smoothed model with MAP parameter estimation via MCMC
+
+4. **Neural networks**
+    * Layers / Layer-wise ops
+        - Add
+        - Flatten
+        - Multiply
+        - Softmax
+        - Fully-connected/Dense
+        - Sparse evolutionary connections
+        - LSTM
+        - Elman-style RNN
+        - Max + average pooling
+        - Dot-product attention
+        - Embedding layer
+        - Restricted Boltzmann machine (w. CD-n training)
+        - 2D deconvolution (w. padding and stride)
+        - 2D convolution (w. padding, dilation, and stride)
+        - 1D convolution (w. padding, dilation, stride, and causality)
+    * Modules
+        - Bidirectional LSTM
+        - ResNet-style residual blocks (identity and convolution)
+        - WaveNet-style residual blocks with dilated causal convolutions
+        - Transformer-style multi-headed scaled dot product attention
+    * Regularizers
+        - Dropout
+    * Normalization
+        - Batch normalization (spatial and temporal)
+        - Layer normalization (spatial and temporal)
+    * Optimizers
+        - SGD w/ momentum
+        - AdaGrad
+        - RMSProp
+        - Adam
+    * Learning Rate Schedulers
+        - Constant
+        - Exponential
+        - Noam/Transformer
+        - Dlib scheduler
+    * Weight Initializers
+        - Glorot/Xavier uniform and normal
+        - He/Kaiming uniform and normal
+        - Standard and truncated normal
+    * Losses
+        - Cross entropy
+        - Squared error
+        - Bernoulli VAE loss
+        - Wasserstein loss with gradient penalty
+        - Noise contrastive estimation loss
+    * Activations
+        - ReLU
+        - Tanh
+        - Affine
+        - Sigmoid
+        - Leaky ReLU
+        - ELU
+        - SELU
+        - Exponential
+        - Hard Sigmoid
+        - Softplus
+    * Models
+        - Bernoulli variational autoencoder
+        - Wasserstein GAN with gradient penalty
+        - word2vec encoder with skip-gram and CBOW architectures
+    * Utilities
+        - `col2im` (MATLAB port)
+        - `im2col` (MATLAB port)
+        - `conv1D`
+        - `conv2D`
+        - `deconv2D`
+        - `minibatch`
+
+5. **Tree-based models**
+    - Decision trees (CART)
+    - [Bagging] Random forests
+    - [Boosting] Gradient-boosted decision trees
+
+6. **Linear models**
+    - Ridge regression
+    - Logistic regression
+    - Ordinary least squares
+    - Bayesian linear regression w/ conjugate priors
+        - Unknown mean, known variance (Gaussian prior)
+        - Unknown mean, unknown variance (Normal-Gamma / Normal-Inverse-Wishart prior)
+
+7. **n-Gram sequence models**
+    - Maximum likelihood scores
+    - Additive/Lidstone smoothing
+    - Simple Good-Turing smoothing
+
+8. **Multi-armed bandit models**
+    - UCB1
+    - LinUCB
+    - Epsilon-greedy
+    - Thompson sampling w/ conjugate priors
+        - Beta-Bernoulli sampler
+
+9. **Reinforcement learning models**
+    - Cross-entropy method agent
+    - First visit on-policy Monte Carlo agent
+    - Weighted incremental importance sampling Monte Carlo agent
+    - Expected SARSA agent
+    - TD-0 Q-learning agent
+    - Dyna-Q / Dyna-Q+ with prioritized sweeping
+
+10. **Nonparametric models**
+    - Nadaraya-Watson kernel regression
+    - k-Nearest neighbors classification and regression
+    - Gaussian process regression
+
+11. **Matrix factorization**
+    - Regularized alternating least-squares
+    - Non-negative matrix factorization
+
+12. **Preprocessing**
+    - Discrete Fourier transform (1D signals)
+    - Discrete cosine transform (type-II) (1D signals)
+    - Bilinear interpolation (2D signals)
+    - Nearest neighbor interpolation (1D and 2D signals)
+    - Autocorrelation (1D signals)
+    - Signal windowing
+    - Text tokenization
+    - Feature hashing
+    - Feature standardization
+    - One-hot encoding / decoding
+    - Huffman coding / decoding
+    - Term frequency-inverse document frequency (TF-IDF) encoding
+    - MFCC encoding
+
+13. **Utilities**
+    - Similarity kernels
+    - Distance metrics
+    - Priority queue
+    - Ball tree
+    - Discrete sampler
+    - Graph processing and generators
 
 ## Contributing
 
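
The new installation section describes two workflows: an editable clone for hacking, and a packaged `pip3` install. As a quick, hedged sanity check of the packaged install, the sketch below imports the top-level package and one utility that this commit's diffs actually reference (`numpy_ml.utils.testing`); any module layout beyond that would be an assumption.

```python
# Minimal post-install smoke test; assumes `pip3 install numpy_ml` completed.
# numpy_ml.utils.testing is used here because it is the submodule imported in
# the numpy_ml/bandits/bandits.py diff below; nothing else is assumed.
import numpy_ml
from numpy_ml.utils.testing import is_number

print(numpy_ml.__file__)   # where pip placed the package
print(is_number(3.14))     # expected: True (a simple numeric-type check)
```
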
numpy_ml/bandits/bandits.py (+27 −10)

@@ -4,7 +4,7 @@
 
 import numpy as np
 
-from ..utils.testing import random_one_hot_matrix, is_number
+from numpy_ml.utils.testing import random_one_hot_matrix, is_number
 
 
 class Bandit(ABC):
@@ -104,6 +104,7 @@ def __init__(self, payoffs, payoff_probs):
         self.payoff_probs = payoff_probs
         self.arm_evs = np.array([sum(p * v) for p, v in zip(payoff_probs, payoffs)])
         self.best_ev = np.max(self.arm_evs)
+        self.best_arm = np.argmax(self.arm_evs)
 
     @property
     def hyperparameters(self):
@@ -127,8 +128,10 @@ def oracle_payoff(self, context=None):
         -------
         optimal_rwd : float
             The expected reward under an optimal policy.
+        optimal_arm : float
+            The arm ID with the largest expected reward.
         """
-        return self.best_ev
+        return self.best_ev, self.best_arm
 
     def _pull(self, arm_id, context):
         payoffs = self.payoffs[arm_id]
@@ -159,6 +162,7 @@ def __init__(self, payoff_probs):
 
         self.arm_evs = self.payoff_probs
         self.best_ev = np.max(self.arm_evs)
+        self.best_arm = np.argmax(self.arm_evs)
 
     @property
     def hyperparameters(self):
@@ -181,8 +185,10 @@ def oracle_payoff(self, context=None):
         -------
         optimal_rwd : float
             The expected reward under an optimal policy.
+        optimal_arm : float
+            The arm ID with the largest expected reward.
         """
-        return self.best_ev
+        return self.best_ev, self.best_arm
 
     def _pull(self, arm_id, context):
         return int(np.random.rand() <= self.payoff_probs[arm_id])
@@ -217,6 +223,7 @@ def __init__(self, payoff_dists, payoff_probs):
         self.payoff_probs = payoff_probs
         self.arm_evs = np.array([mu for (mu, var) in payoff_dists])
         self.best_ev = np.max(self.arm_evs)
+        self.best_arm = np.argmax(self.arm_evs)
 
     @property
     def hyperparameters(self):
@@ -249,8 +256,10 @@ def oracle_payoff(self, context=None):
         -------
         optimal_rwd : float
             The expected reward under an optimal policy.
+        optimal_arm : float
+            The arm ID with the largest expected reward.
         """
-        return self.best_ev
+        return self.best_ev, self.best_arm
 
 
 class ShortestPathBandit(Bandit):
@@ -282,6 +291,7 @@ def __init__(self, G, start_vertex, end_vertex):
 
         self.arm_evs = self._calc_arm_evs()
         self.best_ev = np.max(self.arm_evs)
+        self.best_arm = np.argmax(self.arm_evs)
 
         placeholder = [None] * len(self.paths)
         super().__init__(placeholder, placeholder)
@@ -309,8 +319,10 @@ def oracle_payoff(self, context=None):
         -------
         optimal_rwd : float
             The expected reward under an optimal policy.
+        optimal_arm : float
+            The arm ID with the largest expected reward.
         """
-        return self.best_ev
+        return self.best_ev, self.best_arm
 
     def _calc_arm_evs(self):
         I2V = self.G.get_vertex
@@ -353,7 +365,8 @@ def __init__(self, context_probs):
 
         self.context_probs = context_probs
         self.arm_evs = self.context_probs
-        self.best_ev = self.arm_evs.max(axis=1)
+        self.best_evs = self.arm_evs.max(axis=1)
+        self.best_arms = self.arm_evs.argmax(axis=1)
 
     @property
     def hyperparameters(self):
@@ -386,15 +399,17 @@ def oracle_payoff(self, context):
         Parameters
         ----------
         context : :py:class:`ndarray <numpy.ndarray>` of shape `(D, K)` or None
-            The current context matrix for each of the bandit arms, if
-            applicable. Default is None.
+            The current context matrix for each of the bandit arms.
 
         Returns
         -------
         optimal_rwd : float
             The expected reward under an optimal policy.
+        optimal_arm : float
+            The arm ID with the largest expected reward.
         """
-        return context[:, 0] @ self.best_ev
+        context_id = context[:, 0].argmax()
+        return self.best_evs[context_id], self.best_arms[context_id]
 
     def _pull(self, arm_id, context):
         D, K = self.context_probs.shape
@@ -499,9 +514,11 @@ def oracle_payoff(self, context):
         -------
         optimal_rwd : float
             The expected reward under an optimal policy.
+        optimal_arm : float
+            The arm ID with the largest expected reward.
         """
         best_arm = np.argmax(self.arm_evs)
-        return self.arm_evs[best_arm]
+        return self.arm_evs[best_arm], best_arm
 
     def _pull(self, arm_id, context):
         K, thetas = self.K, self.thetas
numpy_ml/bandits/policies.py (+10 −10)

@@ -202,13 +202,12 @@ def __init__(self, C=1, ev_prior=0.5):
 
             \text{UCB}(a, t) = \text{EV}_t(a) + C \sqrt{\frac{2 \log t}{N_t(a)}}
 
-        where :math:`\text{UCB}(a, t)` is the upper confidence bound on the
-        expected value of arm `a` at time `t`, :math:`\text{EV}_t(a)` is the
-        average of the rewards recieved so far from pulling arm `a`, `C` is a
-        parameter controlling the confidence upper bound of the estimate for
-        :math:`\text{UCB}(a, t)` (for logarithmic regret bounds, `C` must
-        equal 1), and :math:`N_t(a)` is the number of times arm `a` has been
-        pulled during the previous `t - 1` timesteps.
+        where :math:`\text{EV}_t(a)` is the average of the rewards received so
+        far from pulling arm `a`, `C` is a free parameter controlling the
+        "optimism" of the confidence upper bound for :math:`\text{UCB}(a, t)`
+        (for logarithmic regret bounds, `C` must equal 1), and :math:`N_t(a)`
+        is the number of times arm `a` has been pulled during the previous `t -
+        1` timesteps.
 
         References
         ----------
@@ -220,7 +219,8 @@ def __init__(self, C=1, ev_prior=0.5):
         ----------
         C : float in (0, +infinity)
             A confidence/optimisim parameter affecting the degree of
-            exploration. The UCB1 algorithm assumes `C=1`. Default is 1.
+            exploration, where larger values encourage greater exploration. The
+            UCB1 algorithm assumes `C=1`. Default is 1.
         ev_prior : float
             The starting expected value for each arm before any data has been
             observed. Default is 0.5.
@@ -292,10 +292,10 @@ def __init__(self, alpha=1, beta=1):
         where :math:`k \in \{1,\ldots,K \}` indexes arms in the MAB and
         :math:`\theta_k` is the parameter of the Bernoulli likelihood for arm
         `k`. The sampler begins by selecting an arm with probability
-        proportional to it's payoff probability under the initial Beta prior.
+        proportional to its payoff probability under the initial Beta prior.
         After pulling the sampled arm and receiving a reward, `r`, the sampler
         computes the posterior over the model parameters (arm payoffs) via
-        Bayes' rule, and then samples a new action in proportion to it's payoff
+        Bayes' rule, and then samples a new action in proportion to its payoff
         probability under this posterior. This process (i.e., sample action
         from posterior, take action and receive reward, compute updated
         posterior) is repeated until the number of trials is exhausted.
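
The reworded UCB1 docstring keeps the same selection rule, UCB(a, t) = EV_t(a) + C * sqrt(2 log t / N_t(a)). The standalone sketch below simply evaluates that score so the roles of `C`, the running average EV_t(a), and the pull count N_t(a) are concrete; it is an illustration of the documented formula, not the `policies.py` implementation itself.

```python
import numpy as np

def ucb1_scores(ev_estimates, pull_counts, t, C=1.0):
    """Evaluate UCB(a, t) = EV_t(a) + C * sqrt(2 * log(t) / N_t(a)) per arm.

    Illustration of the docstring formula only; not the library's code.
    """
    ev = np.asarray(ev_estimates, dtype=float)
    n = np.asarray(pull_counts, dtype=float)
    # Arms that have never been pulled get an infinite bonus so each arm is
    # tried at least once before the confidence bound takes over.
    bonus = np.where(n > 0, C * np.sqrt(2.0 * np.log(t) / np.maximum(n, 1.0)), np.inf)
    return ev + bonus

# Three arms after t = 10 rounds: running reward averages and pull counts.
scores = ucb1_scores(ev_estimates=[0.2, 0.5, 0.4], pull_counts=[3, 4, 3], t=10)
print(int(np.argmax(scores)))  # index of the arm UCB1 would pull next
```
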
