Skip to content

Commit cc4b13a

Browse files
committedAug 18, 2023
Linear Regression.
1 parent db61633 commit cc4b13a

File tree

1 file changed

+177
-0
lines changed

1 file changed

+177
-0
lines changed
 

‎Linear Regression.ipynb

+177
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"This notebook illustrates some features of the JAX library in the context of a simple linear regression problem. In real life, we could fit this model much more simply by using the least squares estimator\n",
8+
"$$\n",
9+
"\\hat{\\beta}=(X^T X)^{-1}X^T y,\n",
10+
"$$\n",
11+
"but here we will optimize the mean-square error loss function via gradient descent."
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": 1,
17+
"metadata": {},
18+
"outputs": [
19+
{
20+
"name": "stdout",
21+
"output_type": "stream",
22+
"text": [
23+
"seed is 1692330859\n"
24+
]
25+
}
26+
],
27+
"source": [
28+
"import jax\n",
29+
"import jax.numpy as jnp\n",
30+
"import jax.random as random\n",
31+
"from collections import namedtuple\n",
32+
"import time\n",
33+
"SEED = int(time.time())\n",
34+
"print(f\"seed is {SEED}\")\n",
35+
"key = random.key(SEED)\n",
36+
"ModelParameters = namedtuple('ModelParameters', 'w b')"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": 2,
42+
"metadata": {},
43+
"outputs": [],
44+
"source": [
45+
"@jax.jit\n",
46+
"def predict(params: ModelParameters, x: jnp.array) -> jnp.array:\n",
47+
" return params.w.dot(x) + params.b\n",
48+
"vpredict = jax.vmap(predict, in_axes=[None, 0])"
49+
]
50+
},
51+
{
52+
"cell_type": "markdown",
53+
"metadata": {},
54+
"source": [
55+
"JAX random numbers are a bit weird -- we have to push around some state in the `key` variable."
56+
]
57+
},
58+
{
59+
"cell_type": "code",
60+
"execution_count": 3,
61+
"metadata": {},
62+
"outputs": [],
63+
"source": [
64+
"xs = random.normal(key, shape=(200,1))\n",
65+
"key, _ = random.split(key)\n",
66+
"Wtrue = random.normal(key, shape=(1,))\n",
67+
"key, _ = random.split(key)\n",
68+
"btrue = random.normal(key, shape=(1,))\n",
69+
"true_params = ModelParameters(Wtrue, btrue)\n",
70+
"true_ys = vpredict(true_params, xs)\n",
71+
"\n",
72+
"key, _ = random.split(key)\n",
73+
"W = random.normal(key, shape=(1,))\n",
74+
"key, _ = random.split(key)\n",
75+
"b = random.normal(key, shape=(1,))\n",
76+
"params = ModelParameters(W, b)"
77+
]
78+
},
79+
{
80+
"cell_type": "markdown",
81+
"metadata": {},
82+
"source": [
83+
"Here we define our loss function, the mean of the square of the errors."
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": 4,
89+
"metadata": {},
90+
"outputs": [],
91+
"source": [
92+
"@jax.jit\n",
93+
"def mse(parameters: ModelParameters, xs: jnp.array, ys: jnp.array) -> jnp.array:\n",
94+
" y_hats = vpredict(parameters, xs)\n",
95+
" return jax.numpy.mean(jnp.square(y_hats - ys))\n",
96+
"grad_mse = jax.grad(mse)"
97+
]
98+
},
99+
{
100+
"cell_type": "markdown",
101+
"metadata": {},
102+
"source": [
103+
"Below the model is fitted."
104+
]
105+
},
106+
{
107+
"cell_type": "code",
108+
"execution_count": 5,
109+
"metadata": {},
110+
"outputs": [
111+
{
112+
"name": "stdout",
113+
"output_type": "stream",
114+
"text": [
115+
"ModelParameters(w=Array([0.2682777], dtype=float32), b=Array([0.1782908], dtype=float32))\n",
116+
"ModelParameters(w=Array([1.8501179], dtype=float32), b=Array([0.6013752], dtype=float32))\n",
117+
"ModelParameters(w=Array([2.0439496], dtype=float32), b=Array([0.6323448], dtype=float32))\n",
118+
"ModelParameters(w=Array([2.0680373], dtype=float32), b=Array([0.63335055], dtype=float32))\n",
119+
"ModelParameters(w=Array([2.0710773], dtype=float32), b=Array([0.6330958], dtype=float32))\n",
120+
"ModelParameters(w=Array([2.0714667], dtype=float32), b=Array([0.63301265], dtype=float32))\n",
121+
"ModelParameters(w=Array([2.0715175], dtype=float32), b=Array([0.6329955], dtype=float32))\n",
122+
"ModelParameters(w=Array([2.0715194], dtype=float32), b=Array([0.6329933], dtype=float32))\n",
123+
"ModelParameters(w=Array([2.0715194], dtype=float32), b=Array([0.6329933], dtype=float32))\n",
124+
"ModelParameters(w=Array([2.0715194], dtype=float32), b=Array([0.6329933], dtype=float32))\n"
125+
]
126+
}
127+
],
128+
"source": [
129+
"lr = 1e-2\n",
130+
"for i in range(1000):\n",
131+
" batch_grads = grad_mse(params, xs, true_ys)\n",
132+
" params = ModelParameters(params.w - lr * batch_grads.w, params.b - lr * batch_grads.b)\n",
133+
" if i % 100 == 0:\n",
134+
" print(params)"
135+
]
136+
},
137+
{
138+
"cell_type": "markdown",
139+
"metadata": {},
140+
"source": [
141+
"Finally, let's compare the true parameters to the learned ones."
142+
]
143+
},
144+
{
145+
"cell_type": "code",
146+
"execution_count": 6,
147+
"metadata": {},
148+
"outputs": [],
149+
"source": [
150+
"assert jnp.isclose(true_params.w, params.w)\n",
151+
"assert jnp.isclose(true_params.b, params.b)"
152+
]
153+
}
154+
],
155+
"metadata": {
156+
"kernelspec": {
157+
"display_name": ".venv",
158+
"language": "python",
159+
"name": "python3"
160+
},
161+
"language_info": {
162+
"codemirror_mode": {
163+
"name": "ipython",
164+
"version": 3
165+
},
166+
"file_extension": ".py",
167+
"mimetype": "text/x-python",
168+
"name": "python",
169+
"nbconvert_exporter": "python",
170+
"pygments_lexer": "ipython3",
171+
"version": "3.10.12"
172+
},
173+
"orig_nbformat": 4
174+
},
175+
"nbformat": 4,
176+
"nbformat_minor": 2
177+
}

0 commit comments

Comments
 (0)
Please sign in to comment.