|
9 | 9 | "\n",
|
10 | 10 | "In this exercise you will use automatic differentiation in JAX and estimagic to solve the previous problem.\n",
|
11 | 11 | "\n",
|
12 |
| - "> Note. Here you will only find the solution for Unix and Linux.\n", |
13 |
| - "\n", |
14 | 12 | "## Resources\n",
|
15 | 13 | "\n",
|
16 | 14 | "- https://jax.readthedocs.io/en/latest/jax.numpy.html\n",
|
|
49 | 47 | "outputs": [],
|
50 | 48 | "source": [
|
51 | 49 | "def criterion(x):\n",
|
52 |
| - " first = (x[\"a\"] - jnp.pi) ** 4 \n", |
| 50 | + " first = (x[\"a\"] - jnp.pi) ** 2\n", |
53 | 51 | " second = jnp.sum((x[\"b\"] - jnp.arange(3)) ** 2)\n",
|
54 | 52 | " third = jnp.sum((x[\"c\"] - jnp.eye(2)) ** 2)\n",
|
55 | 53 | " return first + second + third\n",
|
|
71 | 69 | {
|
72 | 70 | "data": {
|
73 | 71 | "text/plain": [
|
74 |
| - "DeviceArray(25.0352401, dtype=float64)" |
| 72 | + "DeviceArray(8.58641909, dtype=float64)" |
75 | 73 | ]
|
76 | 74 | },
|
77 | 75 | "execution_count": 3,
|
|
83 | 81 | "criterion(start_params)"
|
84 | 82 | ]
|
85 | 83 | },
|
| 84 | + { |
| 85 | + "cell_type": "markdown", |
| 86 | + "id": "c690e3bf", |
| 87 | + "metadata": {}, |
| 88 | + "source": [ |
| 89 | + "## Solution, Task 1 (Windows):" |
| 90 | + ] |
| 91 | + }, |
| 92 | + { |
| 93 | + "cell_type": "code", |
| 94 | + "execution_count": 4, |
| 95 | + "id": "22bfb278", |
| 96 | + "metadata": {}, |
| 97 | + "outputs": [], |
| 98 | + "source": [ |
| 99 | + "import numpy as np\n", |
| 100 | + "\n", |
| 101 | + "def criterion_windows(x):\n", |
| 102 | + " first = (x[\"a\"] - np.pi) ** 2\n", |
| 103 | + " second = np.sum((x[\"b\"] - np.arange(3)) ** 2)\n", |
| 104 | + " third = np.sum((x[\"c\"] - np.eye(2)) ** 2)\n", |
| 105 | + " return first + second + third\n", |
| 106 | + " \n", |
| 107 | + " \n", |
| 108 | + "start_params_windows = {\n", |
| 109 | + " \"a\": 1.,\n", |
| 110 | + " \"b\": np.ones(3).astype(float),\n", |
| 111 | + " \"c\": np.ones((2, 2)).astype(float)\n", |
| 112 | + "}" |
| 113 | + ] |
| 114 | + }, |
86 | 115 | {
|
87 | 116 | "cell_type": "markdown",
|
88 | 117 | "id": "9c2814c9",
|
89 | 118 | "metadata": {},
|
90 | 119 | "source": [
|
91 |
| - "## Task 2: Gradient\n", |
| 120 | + "## Solution, Task 2: Gradient\n", |
92 | 121 | "\n",
|
93 | 122 | "- Compute the gradient of the criterion (the whole function). Hint: look at the [`autodiff_cookbook` documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) and slides if you have any questions."
|
94 | 123 | ]
|
95 | 124 | },
|
96 | 125 | {
|
97 | 126 | "cell_type": "code",
|
98 |
| - "execution_count": 4, |
| 127 | + "execution_count": 5, |
99 | 128 | "id": "122f2831",
|
100 | 129 | "metadata": {},
|
101 | 130 | "outputs": [
|
102 | 131 | {
|
103 | 132 | "data": {
|
104 | 133 | "text/plain": [
|
105 |
| - "{'a': DeviceArray(-39.28896575, dtype=float64, weak_type=True),\n", |
| 134 | + "{'a': DeviceArray(-4.28318531, dtype=float64, weak_type=True),\n", |
106 | 135 | " 'b': DeviceArray([ 2., 0., -2.], dtype=float64),\n",
|
107 | 136 | " 'c': DeviceArray([[0., 2.],\n",
|
108 | 137 | " [2., 0.]], dtype=float64)}"
|
109 | 138 | ]
|
110 | 139 | },
|
111 |
| - "execution_count": 4, |
| 140 | + "execution_count": 5, |
112 | 141 | "metadata": {},
|
113 | 142 | "output_type": "execute_result"
|
114 | 143 | }
|
|
120 | 149 | },
|
121 | 150 | {
|
122 | 151 | "cell_type": "code",
|
123 |
| - "execution_count": 5, |
| 152 | + "execution_count": 6, |
124 | 153 | "id": "7aefa2e9",
|
125 | 154 | "metadata": {},
|
126 | 155 | "outputs": [
|
127 | 156 | {
|
128 | 157 | "name": "stdout",
|
129 | 158 | "output_type": "stream",
|
130 | 159 | "text": [
|
131 |
| - "8.25 ms ± 975 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" |
| 160 | + "11.5 ms ± 2.05 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" |
132 | 161 | ]
|
133 | 162 | }
|
134 | 163 | ],
|
|
138 | 167 | },
|
139 | 168 | {
|
140 | 169 | "cell_type": "code",
|
141 |
| - "execution_count": 6, |
| 170 | + "execution_count": 7, |
142 | 171 | "id": "dd8ffcc6",
|
143 | 172 | "metadata": {},
|
144 | 173 | "outputs": [
|
145 | 174 | {
|
146 | 175 | "name": "stdout",
|
147 | 176 | "output_type": "stream",
|
148 | 177 | "text": [
|
149 |
| - "10.7 µs ± 248 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" |
| 178 | + "17.2 µs ± 7.57 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" |
150 | 179 | ]
|
151 | 180 | }
|
152 | 181 | ],
|
|
155 | 184 | "%timeit jitted_gradient(start_params)"
|
156 | 185 | ]
|
157 | 186 | },
|
| 187 | + { |
| 188 | + "cell_type": "markdown", |
| 189 | + "id": "92f96dcc", |
| 190 | + "metadata": {}, |
| 191 | + "source": [ |
| 192 | + "## Solution, Task 2 (Windows):\n", |
| 193 | + "\n", |
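| 194 | + "Instead of using automatic differentiation, this variant implements the gradient by hand with NumPy. The analytical gradient of the function is given by:\n", |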
| 195 | + "\n", |
| 196 | + "- $\\partial_a f(a, b, C) = 2 (a - \\pi)$\n", |
| 197 | + "- $\\partial_b f(a, b, C) = 2 (b - \\begin{pmatrix}0,1,2\\end{pmatrix}^\\top)$\n", |
| 198 | + "- $\\partial_C f(a, b, C) = 2 (C - I_2)$\n", |
| 199 | + "\n", |
| 200 | + "---\n", |
| 201 | + "\n", |
| 202 | + "- Implement the analytical gradient\n", |
| 203 | + " - return the gradient in the form of `{\"a\": ..., \"b\": ..., \"c\": ...}`" |
| 204 | + ] |
| 205 | + }, |
| 206 | + { |
| 207 | + "cell_type": "code", |
| 208 | + "execution_count": 8, |
| 209 | + "id": "2201091d", |
| 210 | + "metadata": {}, |
| 211 | + "outputs": [], |
| 212 | + "source": [ |
| 213 | + "def gradient(params):\n", |
| 214 | + " return {\n", |
| 215 | + " \"a\": 2 * (params[\"a\"] - np.pi),\n", |
| 216 | + " \"b\": 2 * (params[\"b\"] - np.array([0, 1, 2])),\n", |
| 217 | + " \"c\": 2 * (params[\"c\"] - np.eye(2))\n", |
| 218 | + " }" |
| 219 | + ] |
| 220 | + }, |
158 | 221 | {
|
159 | 222 | "cell_type": "markdown",
|
160 | 223 | "id": "e9e578b5",
|
161 | 224 | "metadata": {},
|
162 | 225 | "source": [
|
163 |
| - "## Task 3: Minimize\n", |
| 226 | + "## Solution, Task 3: Minimize\n", |
164 | 227 | "\n",
|
165 | 228 | "- Use estimagic to minimize the criterion\n",
|
166 | 229 | " - pass the gradient function you computed above to the minimize call.\n",
|
|
169 | 232 | },
|
170 | 233 | {
|
171 | 234 | "cell_type": "code",
|
172 |
| - "execution_count": 7, |
| 235 | + "execution_count": 9, |
173 | 236 | "id": "f23ead7a",
|
174 | 237 | "metadata": {},
|
175 | 238 | "outputs": [
|
176 | 239 | {
|
177 | 240 | "data": {
|
178 | 241 | "text/plain": [
|
179 |
| - "{'a': 3.1292550669508072,\n", |
180 |
| - " 'b': DeviceArray([-4.86427306e-06, 1.00000000e+00, 1.99999782e+00], dtype=float64),\n", |
181 |
| - " 'c': DeviceArray([[ 1.00000000e+00, -4.86427306e-06],\n", |
182 |
| - " [-4.86427306e-06, 1.00000000e+00]], dtype=float64)}" |
| 242 | + "{'a': 3.141592653589793,\n", |
| 243 | + " 'b': DeviceArray([3.33066907e-16, 1.00000000e+00, 2.00000000e+00], dtype=float64),\n", |
| 244 | + " 'c': DeviceArray([[1.00000000e+00, 3.33066907e-16],\n", |
| 245 | + " [3.33066907e-16, 1.00000000e+00]], dtype=float64)}" |
183 | 246 | ]
|
184 | 247 | },
|
185 |
| - "execution_count": 7, |
| 248 | + "execution_count": 9, |
186 | 249 | "metadata": {},
|
187 | 250 | "output_type": "execute_result"
|
188 | 251 | }
|
|
197 | 260 | "\n",
|
198 | 261 | "res.params"
|
199 | 262 | ]
|
| 263 | + }, |
| 264 | + { |
| 265 | + "cell_type": "code", |
| 266 | + "execution_count": 10, |
| 267 | + "id": "1ef9fc6e", |
| 268 | + "metadata": {}, |
| 269 | + "outputs": [ |
| 270 | + { |
| 271 | + "data": { |
| 272 | + "text/plain": [ |
| 273 | + "{'a': 3.141592653589793,\n", |
| 274 | + " 'b': array([3.33066907e-16, 1.00000000e+00, 2.00000000e+00]),\n", |
| 275 | + " 'c': array([[1.00000000e+00, 3.33066907e-16],\n", |
| 276 | + " [3.33066907e-16, 1.00000000e+00]])}" |
| 277 | + ] |
| 278 | + }, |
| 279 | + "execution_count": 10, |
| 280 | + "metadata": {}, |
| 281 | + "output_type": "execute_result" |
| 282 | + } |
| 283 | + ], |
| 284 | + "source": [ |
| 285 | + "res = em.minimize(\n", |
| 286 | + " criterion=criterion_windows,\n", |
| 287 | + " derivative=gradient,\n", |
| 288 | + " params=start_params_windows,\n", |
| 289 | + " algorithm=\"scipy_lbfgsb\",\n", |
| 290 | + ")\n", |
| 291 | + "\n", |
| 292 | + "res.params" |
| 293 | + ] |
200 | 294 | }
|
201 | 295 | ],
|
202 | 296 | "metadata": {
|
|
0 commit comments