
Commit bed23a9

Update exercises and slides
1 parent 9cdfe16 commit bed23a9

File tree: 4 files changed, +127 -25 lines changed


src/scipy_dev/notebooks/07_automatic_differentiation.ipynb (+10 -2)
```diff
@@ -69,8 +69,16 @@
 "\n",
 "## Task 2 (Windows): Gradient\n",
 "\n",
-"- Compute the gradient of the criterion (the whole function) analytically\n",
-"- Implement the analytical gradient"
+"The analytical gradient of the function is given by:\n",
+"\n",
+"- $\\partial_a f(a, b, C) = 2 (a - \\pi)$\n",
+"- $\\partial_b f(a, b, C) = 2 (b - \\begin{pmatrix}0,1,2\\end{pmatrix}^\\top)$\n",
+"- $\\partial_C f(a, b, C) = 2 (C - I_2)$\n",
+"\n",
+"---\n",
+"\n",
+"- Implement the analytical gradient\n",
+"    - return the gradient in the form of `{\"a\": ..., \"b\": ..., \"C\": ...}`"
 ]
 },
 {
```
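
A quick way to sanity-check these formulas is to compare a hand-coded gradient against `jax.grad`. The following is a minimal sketch, not part of the commit; it assumes the criterion as defined in the solutions notebook below and 64-bit floats:

```python
# Minimal sketch (editor's illustration, not in the commit): cross-check the
# analytical gradient against JAX autodiff. Assumes the criterion from the
# solutions notebook; float64 is enabled to match the DeviceArray outputs.
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def criterion(x):
    first = (x["a"] - jnp.pi) ** 2
    second = jnp.sum((x["b"] - jnp.arange(3)) ** 2)
    third = jnp.sum((x["c"] - jnp.eye(2)) ** 2)
    return first + second + third


def analytical_gradient(x):
    # The three partials listed above (chain rule on each squared term).
    return {
        "a": 2 * (x["a"] - jnp.pi),
        "b": 2 * (x["b"] - jnp.arange(3)),
        "c": 2 * (x["c"] - jnp.eye(2)),
    }


params = {"a": 1.0, "b": jnp.ones(3), "c": jnp.ones((2, 2))}
auto = jax.grad(criterion)(params)
manual = analytical_gradient(params)
for key in auto:
    assert jnp.allclose(auto[key], manual[key])
```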

src/scipy_dev/notebooks/solutions/07_automatic_differentiation.ipynb (+113 -19)
```diff
@@ -9,8 +9,6 @@
 "\n",
 "In this exercise you will use automatic differentiation in JAX and estimagic to solve the previous problem.\n",
 "\n",
-"> Note. Here you will only find the solution for Unix and Linux.\n",
-"\n",
 "## Resources\n",
 "\n",
 "- https://jax.readthedocs.io/en/latest/jax.numpy.html\n",
```
```diff
@@ -49,7 +47,7 @@
 "outputs": [],
 "source": [
 "def criterion(x):\n",
-"    first = (x[\"a\"] - jnp.pi) ** 4 \n",
+"    first = (x[\"a\"] - jnp.pi) ** 2\n",
 "    second = jnp.sum((x[\"b\"] - jnp.arange(3)) ** 2)\n",
 "    third = jnp.sum((x[\"c\"] - jnp.eye(2)) ** 2)\n",
 "    return first + second + third\n",
```
```diff
@@ -71,7 +69,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(25.0352401, dtype=float64)"
+"DeviceArray(8.58641909, dtype=float64)"
 ]
 },
 "execution_count": 3,
```
```diff
@@ -83,32 +81,63 @@
 "criterion(start_params)"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "c690e3bf",
+"metadata": {},
+"source": [
+"## Solution, Task 1 (Windows):"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 4,
+"id": "22bfb278",
+"metadata": {},
+"outputs": [],
+"source": [
+"import numpy as np\n",
+"\n",
+"def criterion_windows(x):\n",
+"    first = (x[\"a\"] - np.pi) ** 2\n",
+"    second = np.sum((x[\"b\"] - np.arange(3)) ** 2)\n",
+"    third = np.sum((x[\"c\"] - np.eye(2)) ** 2)\n",
+"    return first + second + third\n",
+"\n",
+"\n",
+"start_params_windows = {\n",
+"    \"a\": 1.,\n",
+"    \"b\": np.ones(3).astype(float),\n",
+"    \"c\": np.ones((2, 2)).astype(float)\n",
+"}"
+]
+},
 {
 "cell_type": "markdown",
 "id": "9c2814c9",
 "metadata": {},
 "source": [
-"## Task 2: Gradient\n",
+"## Solution, Task 2: Gradient\n",
 "\n",
 "- Compute the gradient of the criterion (the whole function). Hint: look at the [`autodiff_cookbook` documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) and slides if you have any questions."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 5,
 "id": "122f2831",
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
-"{'a': DeviceArray(-39.28896575, dtype=float64, weak_type=True),\n",
+"{'a': DeviceArray(-4.28318531, dtype=float64, weak_type=True),\n",
 " 'b': DeviceArray([ 2.,  0., -2.], dtype=float64),\n",
 " 'c': DeviceArray([[0., 2.],\n",
 "        [2., 0.]], dtype=float64)}"
 ]
 },
-"execution_count": 4,
+"execution_count": 5,
 "metadata": {},
 "output_type": "execute_result"
 }
```
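
The same check works for the gradient output: $\partial_a f$ at $a = 1$ is $2(1 - \pi) \approx -4.28319$, as shown, whereas the old $-39.289$ was $4(1 - \pi)^3$ from the former quartic term.
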
```diff
@@ -120,15 +149,15 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 6,
 "id": "7aefa2e9",
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"8.25 ms ± 975 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+"11.5 ms ± 2.05 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
 ]
 }
 ],
```
```diff
@@ -138,15 +167,15 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 7,
 "id": "dd8ffcc6",
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"10.7 µs ± 248 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
+"17.2 µs ± 7.57 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
 ]
 }
 ],
```
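
The two timings bracket the cost of tracing versus running compiled code: a plain `jax.grad` function is re-traced through Python on every call (milliseconds), while the jitted version dispatches a compiled XLA computation (microseconds). The cells defining `gradient` and `jitted_gradient` fall outside the diff, so the construction below is an assumption:

```python
# Sketch (assumed construction; the defining cells are not shown in the diff).
# Reuses `criterion` and `start_params` from the notebook cells above.
import jax

gradient = jax.grad(criterion)       # reverse-mode gradient, traced per call
jitted_gradient = jax.jit(gradient)  # compiled once by XLA, then reused

jitted_gradient(start_params)        # first call pays the compilation cost
```
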
```diff
@@ -155,12 +184,46 @@
 "%timeit jitted_gradient(start_params)"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "92f96dcc",
+"metadata": {},
+"source": [
+"## Solution, Task 2 (Windows):\n",
+"\n",
+"The analytical gradient of the function is given by:\n",
+"\n",
+"- $\\partial_a f(a, b, C) = 2 (a - \\pi)$\n",
+"- $\\partial_b f(a, b, C) = 2 (b - \\begin{pmatrix}0,1,2\\end{pmatrix}^\\top)$\n",
+"- $\\partial_C f(a, b, C) = 2 (C - I_2)$\n",
+"\n",
+"---\n",
+"\n",
+"- Implement the analytical gradient\n",
+"    - return the gradient in the form of `{\"a\": ..., \"b\": ..., \"C\": ...}`"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
+"id": "2201091d",
+"metadata": {},
+"outputs": [],
+"source": [
+"def gradient(params):\n",
+"    return {\n",
+"        \"a\": 2 * (params[\"a\"] - np.pi),\n",
+"        \"b\": 2 * (params[\"b\"] - np.array([0, 1, 2])),\n",
+"        \"c\": 2 * (params[\"c\"] - np.eye(2))\n",
+"    }"
+]
+},
 {
 "cell_type": "markdown",
 "id": "e9e578b5",
 "metadata": {},
 "source": [
-"## Task 3: Minimize\n",
+"## Solution, Task 3: Minimize\n",
 "\n",
 "- Use estimagic to minimize the criterion\n",
 "    - pass the gradient function you computed above to the minimize call.\n",
```
```diff
@@ -169,20 +232,20 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 9,
 "id": "f23ead7a",
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
-"{'a': 3.1292550669508072,\n",
-" 'b': DeviceArray([-4.86427306e-06, 1.00000000e+00, 1.99999782e+00], dtype=float64),\n",
-" 'c': DeviceArray([[ 1.00000000e+00, -4.86427306e-06],\n",
-"        [-4.86427306e-06, 1.00000000e+00]], dtype=float64)}"
+"{'a': 3.141592653589793,\n",
+" 'b': DeviceArray([3.33066907e-16, 1.00000000e+00, 2.00000000e+00], dtype=float64),\n",
+" 'c': DeviceArray([[1.00000000e+00, 3.33066907e-16],\n",
+"        [3.33066907e-16, 1.00000000e+00]], dtype=float64)}"
 ]
 },
-"execution_count": 7,
+"execution_count": 9,
 "metadata": {},
 "output_type": "execute_result"
 }
```
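
The exact optimum $(\pi, (0, 1, 2)^\top, I_2)$ is expected now: with the squared term the criterion is a strictly convex quadratic, which `scipy_lbfgsb` solves essentially to machine precision. The former $(a - \pi)^4$ term has vanishing curvature at the optimum, which is why the old run stalled at $a \approx 3.1293$.
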
```diff
@@ -197,6 +260,37 @@
 "\n",
 "res.params"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"id": "1ef9fc6e",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"{'a': 3.141592653589793,\n",
+" 'b': array([3.33066907e-16, 1.00000000e+00, 2.00000000e+00]),\n",
+" 'c': array([[1.00000000e+00, 3.33066907e-16],\n",
+"        [3.33066907e-16, 1.00000000e+00]])}"
+]
+},
+"execution_count": 10,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"res = em.minimize(\n",
+"    criterion=criterion_windows,\n",
+"    derivative=gradient,\n",
+"    params=start_params_windows,\n",
+"    algorithm=\"scipy_lbfgsb\",\n",
+")\n",
+"\n",
+"res.params"
+]
 }
 ],
 "metadata": {
```

src/scipy_dev/presentation/main.md (+3 -3)
````diff
@@ -1770,8 +1770,8 @@ Economic problem:
 >>> import jax.numpy as jnp
 >>> from jaxopt import LBFGS
 
->>> x0 = jnp.array([1.0, 2, 3])
->>> shift = x0.copy()
+>>> x0 = jnp.array([1.0, -2, -5])
+>>> shift = jnp.array([-2.0, -4, -6])
 
 >>> def criterion(x, shift):
 ...     return jnp.vdot(x, x + shift)
@@ -1780,7 +1780,7 @@ Economic problem:
 
 >>> result = solver.run(init_params=x0, shift=shift)
 >>> result.params
-DeviceArray([-0.5, -1. , -1.5], dtype=float64)
+DeviceArray([1.0, 2.0, 3.0], dtype=float64)
 ```
 
 </div>
````
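
The revised numbers are self-consistent: for $f(x) = x^\top (x + s)$ the first-order condition gives the minimizer in closed form,

$$\nabla f(x) = 2x + s = 0 \quad\Longrightarrow\quad x^* = -\tfrac{1}{2} s,$$

so $s = (-2, -4, -6)^\top$ yields $x^* = (1, 2, 3)^\top$, the new `result.params`; the old choice $s = x_0 = (1, 2, 3)^\top$ gave $x^* = (-0.5, -1, -1.5)^\top$.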

src/scipy_dev/source_repo/test_installation.py (+1 -1)
```diff
@@ -10,7 +10,7 @@ def is_installed(executable):
 # Check Installations
 # ======================================================================================
 
-required_estimagic_version = "0.3.2"
+required_estimagic_version = "0.4.0"
 
 try:
     import estimagic  # noqa: F401
```
