|
11 | 11 | "**Update:** August 21, 2024\n",
|
12 | 12 | "\n",
|
13 | 13 | "**Implementation:** PAIN architecture presented in:\n",
|
14 |
| - "**[IEEE publication](https://ieeexplore.ieee.org/abstract/document/8682767)**" |
| 14 | + "**[IEEE publication](https://ieeexplore.ieee.org/abstract/document/8682767)**\n", |
| 15 | + "\n", |
| 16 | + "***NOTE:*** Code is setup to reconstruct 7x7 compressed noisy images. Training parameters may need to be adjusted for 4x4, 14x14, and 28x28 cases.\n", |
| 17 | + "\n", |
| 18 | + "Python Version: 3.11.5 \n", |
| 19 | + "Jupyer Notebook Version: 6.5.4" |
15 | 20 | ]
|
16 | 21 | },
|
17 | 22 | {
|
|
55 | 60 | "\n",
|
56 | 61 | "# Dimension of the compressed/noisy images (width=height) \n",
|
57 | 62 | "# cmp_dim = 4:(for 4x4), 7:(7x7), 14:(14x14), or 28:(28x284)\n",
|
58 |
| - "cmp_dim = 14\n", |
| 63 | + "cmp_dim = 7\n", |
59 | 64 | "\n",
|
60 | 65 | "# Dimension of output, original are 28 x 28\n",
|
61 |
| - "out_dim = 28 " |
| 66 | + "out_dim = 28 \n", |
| 67 | + "\n", |
| 68 | + "# Number of training epochs\n", |
| 69 | + "num_epochs = 25\n", |
| 70 | + "\n", |
| 71 | + "# Size of training batch sizes\n", |
| 72 | + "batch_size = 250" |
62 | 73 | ]
|
63 | 74 | },
|
64 | 75 | {
|
|
122 | 133 | "# Create compressed noisy data\n",
|
123 | 134 | "def create_training_test_data(clean_train, clean_test, cmp_dim, out_dim):\n",
|
124 | 135 | " \n",
|
125 |
| - " # Step 0: Create Training and Validation Sets\n", |
126 |
| - " clean_test, clean_valid = split_dataset_rnd(clean_test)\n", |
127 |
| - " \n", |
128 |
| - " # Step 1: Compress images using median in sliding window\n", |
| 136 | + " # Compress images using median in sliding window\n", |
129 | 137 | " cmp_train = down_sample_list(clean_train, cmp_dim)\n",
|
130 | 138 | " cmp_test = down_sample_list(clean_test, cmp_dim)\n",
|
131 |
| - " cmp_valid = down_sample_list(clean_valid, cmp_dim)\n", |
132 | 139 | "\n",
|
133 |
| - " # Step 2: Add Poisson noise to compressed images\n", |
| 140 | + " # Add Poisson noise to compressed images\n", |
134 | 141 | " noisy_train_ = np.random.poisson(lam=cmp_train)\n",
|
135 | 142 | " noisy_test_ = np.random.poisson(lam=cmp_test)\n",
|
136 |
| - " noisy_valid_ = np.random.poisson(lam=cmp_valid)\n", |
137 | 143 | " \n",
|
| 144 | + " # Scale to [0,255]\n", |
138 | 145 | " noisy_train = np.clip(noisy_train_,0,255)\n",
|
139 | 146 | " noisy_test = np.clip(noisy_test_,0,255)\n",
|
140 |
| - " noisy_valid = np.clip(noisy_valid_,0,255)\n", |
141 | 147 | " \n",
|
142 |
| - " # Step 3: Reshape Input Arrays and scale from [0, 255] to [0,1]\n", |
| 148 | + " # Reshape Compressed/Noisy Arrays and scale to [0,1]\n", |
143 | 149 | " noisy_train = np.array([matrix.reshape(cmp_dim**2,) for matrix in noisy_train/255])\n",
|
144 | 150 | " noisy_test = np.array([matrix.reshape(cmp_dim**2,) for matrix in noisy_test/255])\n",
|
145 |
| - " noisy_valid = np.array([matrix.reshape(cmp_dim**2,) for matrix in noisy_valid/255])\n", |
146 |
| - " \n", |
147 |
| - " # Step 4: Reshape Expected Ouput Arraus and scale from [0, 255] to [0,1]\n", |
148 |
| - " clean_train = np.array([matrix.reshape(out_dim**2,) for matrix in clean_train/255])\n", |
149 |
| - " clean_test = np.array([matrix.reshape(out_dim**2,) for matrix in clean_test/255])\n", |
150 |
| - " clean_valid = np.array([matrix.reshape(out_dim**2,) for matrix in clean_valid/255])\n", |
151 | 151 | " \n",
|
152 |
| - " # Step 5: Return training data\n", |
153 |
| - " return [clean_train, clean_test, clean_valid, noisy_train, noisy_test, noisy_valid]\n", |
| 152 | + " # Return training data\n", |
| 153 | + " return noisy_train, noisy_test\n", |
154 | 154 | "\n",
|
155 | 155 | "# PAIN architecture builder function\n",
|
156 | 156 | "def build_PAIN(in_dim, out_dim, enc_dim = 256):\n",
|
|
188 | 188 | "metadata": {},
|
189 | 189 | "outputs": [],
|
190 | 190 | "source": [
|
| 191 | + "# Check compression dimention variable:\n", |
| 192 | + "if not (cmp_dim in [4, 7, 14, 28]):\n", |
| 193 | + " raise ValueError(f'cmp_dim = {cmp_dim}, not handled. The cmp_dim value must be 4, 7, 14, or 28.')\n", |
| 194 | + "\n", |
191 | 195 | "directory = 'training_data'\n",
|
192 | 196 | "\n",
|
193 | 197 | "# If the directory does not exist, create it\n",
|
194 | 198 | "if not os.path.exists(directory):\n",
|
195 | 199 | " os.makedirs(directory)\n",
|
196 | 200 | "\n",
|
197 | 201 | "# Load the MNIST Dataset\n",
|
198 |
| - "file = f'mnist_training_data_cmp' # Do not modify\n", |
199 |
| - "\n", |
200 |
| - "# If the training data does not exist, create it\n", |
201 |
| - "if not os.path.exists(f'{directory}/{file}_{cmp_dim}x{cmp_dim}.h5'):\n", |
202 |
| - " # Load MNIST\n", |
203 |
| - " (clean_train, _), (clean_test, _) = tf.keras.datasets.mnist.load_data()\n", |
| 202 | + "dat_file = f'{directory}/mnist_{cmp_dim}x{cmp_dim}_train.h5'\n", |
204 | 203 | " \n",
|
205 |
| - " # Create compressed/noisy data, test set, and validation set\n", |
206 |
| - " train_data = create_training_test_data(clean_train, clean_test, cmp_dim, out_dim)\n", |
207 |
| - " \n", |
208 |
| - " # MNIST data\n", |
209 |
| - " clean_train = train_data[0]\n", |
210 |
| - " clean_test = train_data[1]\n", |
211 |
| - " clean_valid = train_data[2]\n", |
212 |
| - " \n", |
213 |
| - " # Compressed noisy signals\n", |
214 |
| - " noisy_train = train_data[3]\n", |
215 |
| - " noisy_test = train_data[4]\n", |
216 |
| - " noisy_valid = train_data[5]\n", |
217 |
| - " \n", |
218 |
| - " # Save with compression\n", |
219 |
| - " with h5py.File(f'{directory}/{file}_{cmp_dim}x{cmp_dim}.h5', 'w') as f:\n", |
220 |
| - " # Save original images\n", |
221 |
| - " f.create_dataset('clean_train', data=clean_train, compression='gzip')\n", |
222 |
| - " f.create_dataset('clean_test', data=clean_test, compression='gzip')\n", |
223 |
| - " f.create_dataset('clean_valid', data=clean_valid, compression='gzip')\n", |
| 204 | + "# Load MNIST\n", |
| 205 | + "(clean_train, _), (clean_test, _) = tf.keras.datasets.mnist.load_data()\n", |
| 206 | + "\n", |
| 207 | + "# If the compressed/noisy data does not exist, create it\n", |
| 208 | + "if not os.path.exists(dat_file):\n", |
| 209 | + " # Create compressed/noisy training/test data\n", |
| 210 | + " noisy_train, noisy_test = create_training_test_data(clean_train, clean_test, cmp_dim, out_dim)\n", |
224 | 211 | " \n",
|
| 212 | + " # Save compressed/noisy training/test data\n", |
| 213 | + " with h5py.File(dat_file, 'w') as f:\n", |
225 | 214 | " # Save compressed noisy images\n",
|
226 | 215 | " f.create_dataset('noisy_train', data=noisy_train, compression='gzip')\n",
|
227 | 216 | " f.create_dataset('noisy_test', data=noisy_test, compression='gzip')\n",
|
228 |
| - " f.create_dataset('noisy_valid', data=noisy_valid, compression='gzip')\n", |
229 | 217 | " \n",
|
230 |
| - "else: # If data exists, load it\n", |
231 |
| - " with h5py.File(f'{directory}/{file}_{cmp_dim}x{cmp_dim}.h5', 'r') as dat_file:\n", |
232 |
| - "\n", |
233 |
| - " # Load original MNIST images\n", |
234 |
| - " clean_train = dat_file['clean_train'][:]\n", |
235 |
| - " clean_test = dat_file['clean_test'][:]\n", |
236 |
| - " clean_valid = dat_file['clean_valid'][:]\n", |
237 |
| - "\n", |
238 |
| - " # Load compressed noisy images\n", |
| 218 | + "else: \n", |
| 219 | + " # Load compressed/noisy training data\n", |
| 220 | + " with h5py.File(dat_file, 'r') as dat_file:\n", |
| 221 | + " # Load compressed noisy training/test images\n", |
239 | 222 | " noisy_train = dat_file['noisy_train'][:]\n",
|
240 | 223 | " noisy_test = dat_file['noisy_test'][:]\n",
|
241 |
| - " noisy_valid = dat_file['noisy_valid'][:]" |
| 224 | + " \n", |
| 225 | + "# Prepare original mnist data for model training\n", |
| 226 | + "clean_train = np.array([matrix.reshape(out_dim**2,) for matrix in clean_train/255])\n", |
| 227 | + "clean_test = np.array([matrix.reshape(out_dim**2,) for matrix in clean_test/255])" |
242 | 228 | ]
|
243 | 229 | },
|
244 | 230 | {
|
|
254 | 240 | "metadata": {},
|
255 | 241 | "outputs": [],
|
256 | 242 | "source": [
|
257 |
| - "# Display: training set\n", |
258 |
| - "#######################\n", |
259 |
| - "\n", |
260 | 243 | "# Create a 2 by 4 subplot handle\n",
|
261 | 244 | "fig, axes = plt.subplots(2, 4, figsize=(9, 4.5))\n",
|
262 | 245 | "axes = axes.flatten()\n",
|
|
311 | 294 | "# Create & Compile the PAIN model\n",
|
312 | 295 | "PAIN = build_PAIN(in_dim=cmp_dim, out_dim=out_dim)\n",
|
313 | 296 | "\n",
|
314 |
| - "# Create an RMSProp optimizer with a specific learning rate\n", |
315 |
| - "RMSp = tf.keras.optimizers.RMSprop(learning_rate=0.05)\n", |
| 297 | + "# Create an optimizer with a specific learning rate\n", |
| 298 | + "tf_opt = tf.keras.optimizers.Adam(learning_rate=0.01)\n", |
316 | 299 | "\n",
|
317 | 300 | "# Compile the model\n",
|
318 |
| - "PAIN.compile(optimizer=RMSp, loss='mean_squared_error')\n", |
| 301 | + "PAIN.compile(optimizer=tf_opt, loss='mean_squared_error')\n", |
319 | 302 | "\n",
|
320 | 303 | "# Train model and saving fitting history\n",
|
321 |
| - "fit_history = PAIN.fit(noisy_train, clean_train, epochs=120, batch_size=250, validation_data=(noisy_test, clean_test))" |
| 304 | + "fit_history = PAIN.fit(noisy_train, clean_train, epochs=num_epochs, batch_size=batch_size, validation_split=0.2)" |
322 | 305 | ]
|
323 | 306 | },
|
324 | 307 | {
|
|
403 | 386 | "axes[8].set_ylabel(f'Original\\n{out_dim} x {out_dim}\\nMNIST',fontdict={'fontsize': 12, 'fontfamily': 'serif'})\n",
|
404 | 387 | "\n",
|
405 | 388 | "# Adjust layout to decrease padding between subplots\n",
|
406 |
| - "plt.subplots_adjust(wspace=0.1, hspace=0.00)\n", |
| 389 | + "plt.subplots_adjust(wspace=0.1, hspace=0)\n", |
407 | 390 | "\n",
|
408 | 391 | "# Save results \n",
|
409 | 392 | "plt.savefig(f'{imdir}/{imfile}_{cmp_dim}x{cmp_dim}_train_PAIN.png')\n",
|
|
416 | 399 | "cell_type": "markdown",
|
417 | 400 | "metadata": {},
|
418 | 401 | "source": [
|
419 |
| - "## Apply PAIN to Validation Set and Display Output" |
| 402 | + "## Apply PAIN to Test Set and Display Output" |
420 | 403 | ]
|
421 | 404 | },
|
422 | 405 | {
|
|
425 | 408 | "metadata": {},
|
426 | 409 | "outputs": [],
|
427 | 410 | "source": [
|
428 |
| - "# Apply PAIN to all validation data inputs\n", |
429 |
| - "pred_valid_out = PAIN.predict(noisy_valid[0:101])" |
| 411 | + "# Apply PAIN to all test data inputs\n", |
| 412 | + "pred_test_out = PAIN.predict(noisy_test[0:101])" |
430 | 413 | ]
|
431 | 414 | },
|
432 | 415 | {
|
|
440 | 423 | "axes = axes.flatten()\n",
|
441 | 424 | "\n",
|
442 | 425 | "# Add title\n",
|
443 |
| - "fig.suptitle('Application of PAIN Architecture\\n(MNIST Validation Set)',fontsize=20,fontweight='bold', fontfamily='serif')\n", |
| 426 | + "fig.suptitle('Application of PAIN Architecture\\n(MNIST Test Set)',fontsize=20,fontweight='bold', fontfamily='serif')\n", |
444 | 427 | "\n",
|
445 |
| - "# Shift window through validation dataset \n", |
| 428 | + "# Shift window through test dataset \n", |
446 | 429 | "sft_idx = 0\n",
|
447 | 430 | "\n",
|
448 | 431 | "# Plot noisy and decompressed realizations in each subplot\n",
|
449 | 432 | "for idx in range(4):\n",
|
450 | 433 | " # Plot noisy\n",
|
451 |
| - " axes[idx].imshow(noisy_valid[idx+sft_idx].reshape(cmp_dim,cmp_dim),cmap='gray')\n", |
| 434 | + " axes[idx].imshow(noisy_test[idx+sft_idx].reshape(cmp_dim,cmp_dim),cmap='gray')\n", |
452 | 435 | " axes[idx].set_xticks([]) # Remove xticklabels\n",
|
453 | 436 | " axes[idx].set_yticks([]) # Remove yticklabels\n",
|
454 | 437 | " axes[idx].set_xlabel('⇩',fontdict={'fontsize': 25, 'fontweight': 'bold', 'fontfamily': 'serif', 'color':'blue'})\n",
|
455 | 438 | " \n",
|
456 | 439 | " # Plot decompressed with PAIN\n",
|
457 |
| - " axes[idx+4].imshow(pred_valid_out[idx+sft_idx].reshape(28,28),cmap='gray')\n", |
| 440 | + " axes[idx+4].imshow(pred_test_out[idx+sft_idx].reshape(28,28),cmap='gray')\n", |
458 | 441 | " axes[idx+4].set_xticks([]) # Remove xticklabels\n",
|
459 | 442 | " axes[idx+4].set_yticks([]) # Remove yticklabels\n",
|
460 | 443 | " \n",
|
461 | 444 | " # Plot original\n",
|
462 |
| - " axes[idx+8].imshow(clean_valid[idx+sft_idx].reshape(28,28),cmap='gray')\n", |
| 445 | + " axes[idx+8].imshow(clean_test[idx+sft_idx].reshape(28,28),cmap='gray')\n", |
463 | 446 | " axes[idx+8].set_xticks([]) # Remove xticklabels\n",
|
464 | 447 | " axes[idx+8].set_yticks([]) # Remove yticklabels\n",
|
465 | 448 | " \n",
|
|
469 | 452 | "axes[8].set_ylabel(f'(Original)\\n{out_dim} x {out_dim}\\nMNIST',fontdict={'fontsize': 12, 'fontfamily': 'serif'})\n",
|
470 | 453 | "\n",
|
471 | 454 | "# Adjust layout to decrease padding between subplots\n",
|
472 |
| - "plt.subplots_adjust(wspace=0.1, hspace=0.25)\n", |
| 455 | + "plt.subplots_adjust(wspace=0.1, hspace=0)\n", |
473 | 456 | "\n",
|
474 | 457 | "# Save results \n",
|
475 |
| - "plt.savefig(f'{imdir}/{imfile}_{cmp_dim}x{cmp_dim}_valid_PAIN.png')\n", |
| 458 | + "plt.savefig(f'{imdir}/{imfile}_{cmp_dim}x{cmp_dim}_test_PAIN.png')\n", |
476 | 459 | "\n",
|
477 | 460 | "# Display the figure\n",
|
478 | 461 | "plt.show()"
|
|